Turn --stop into sync. Ignore the rc of shutdown. (tags/v1.0.0)
| @@ -23,11 +23,13 @@ namespace dataset { | |||
| PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) { | |||
| (void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient") | |||
| .def( | |||
| py::init([](session_id_type id, uint64_t mem_sz, bool spill, int32_t port, int32_t prefetch_sz) { | |||
| py::init([](session_id_type id, uint64_t mem_sz, bool spill, std::optional<std::string> hostname, | |||
| std::optional<int32_t> port, int32_t prefetch_sz) { | |||
| std::shared_ptr<CacheClient> cc; | |||
| CacheClient::Builder builder; | |||
| builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPort(port).SetPrefetchSize( | |||
| prefetch_sz); | |||
| builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPrefetchSize(prefetch_sz); | |||
| if (hostname) builder.SetHostname(hostname.value()); | |||
| if (port) builder.SetPort(port.value()); | |||
| THROW_IF_ERROR(builder.Build(&cc)); | |||
| return cc; | |||
| })) | |||
| @@ -19,10 +19,37 @@ | |||
| #include <iostream> | |||
| #include <string> | |||
| #include "mindspore/core/utils/log_adapter.h" | |||
| #include "minddata/dataset/util/system_pool.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| ConfigManager::ConfigManager() | |||
| : rows_per_buffer_(kCfgRowsPerBuffer), | |||
| num_parallel_workers_(kCfgParallelWorkers), | |||
| worker_connector_size_(kCfgWorkerConnectorSize), | |||
| op_connector_size_(kCfgOpConnectorSize), | |||
| seed_(kCfgDefaultSeed), | |||
| monitor_sampling_interval_(kCfgMonitorSamplingInterval), | |||
| callback_timout_(kCfgCallbackTimeout), | |||
| cache_host_(kCfgDefaultCacheHost), | |||
| cache_port_(kCfgDefaultCachePort) { | |||
| auto env_cache_host = std::getenv("MS_CACHE_HOST"); | |||
| auto env_cache_port = std::getenv("MS_CACHE_PORT"); | |||
| if (env_cache_host) { | |||
| cache_host_ = env_cache_host; | |||
| } | |||
| if (env_cache_port) { | |||
| char *end = nullptr; | |||
| cache_port_ = strtol(env_cache_port, &end, 10); | |||
| if (*end != '\0') { | |||
| MS_LOG(WARNING) << "\nCache port from env variable MS_CACHE_PORT is invalid, back to use default " | |||
| << kCfgDefaultCachePort << std::endl; | |||
| cache_port_ = kCfgDefaultCachePort; | |||
| } | |||
| } | |||
| } | |||
| // A print method typically used for debugging | |||
| void ConfigManager::Print(std::ostream &out) const { | |||
| // Don't show the test/internal ones. Only display the main ones here. | |||
| @@ -42,6 +69,8 @@ Status ConfigManager::FromJson(const nlohmann::json &j) { | |||
| set_op_connector_size(j.value("opConnectorSize", op_connector_size_)); | |||
| set_seed(j.value("seed", seed_)); | |||
| set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_)); | |||
| set_cache_host(j.value("cacheHost", cache_host_)); | |||
| set_cache_port(j.value("cachePort", cache_port_)); | |||
| return Status::OK(); | |||
| } | |||
| @@ -91,5 +120,8 @@ void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_s | |||
| void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; } | |||
| void ConfigManager::set_cache_host(std::string cache_host) { cache_host_ = cache_host; } | |||
| void ConfigManager::set_cache_port(int32_t cache_port) { cache_port_ = cache_port; } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -41,7 +41,7 @@ namespace dataset { | |||
| // those values. | |||
| class ConfigManager { | |||
| public: | |||
| ConfigManager() = default; | |||
| ConfigManager(); | |||
| // destructor | |||
| ~ConfigManager() = default; | |||
| @@ -89,6 +89,14 @@ class ConfigManager { | |||
| // @return The internal worker-to-master connector queue size | |||
| int32_t worker_connector_size() const { return worker_connector_size_; } | |||
| // getter function | |||
| // @return The hostname of cache server | |||
| std::string cache_host() const { return cache_host_; } | |||
| // getter function | |||
| // @return The port of cache server | |||
| int32_t cache_port() const { return cache_port_; } | |||
| // setter function | |||
| // @param rows_per_buffer - The setting to apply to the config | |||
| void set_rows_per_buffer(int32_t rows_per_buffer); | |||
| @@ -105,6 +113,14 @@ class ConfigManager { | |||
| // @param connector_size - The setting to apply to the config | |||
| void set_op_connector_size(int32_t connector_size); | |||
| // setter function | |||
| // @param cache_host - The hostname of cache server | |||
| void set_cache_host(std::string cache_host); | |||
| // setter function | |||
| // @param cache_port - The port of cache server | |||
| void set_cache_port(int32_t cache_port); | |||
| uint32_t seed() const; | |||
| // setter function | |||
| @@ -128,13 +144,15 @@ class ConfigManager { | |||
| int32_t callback_timeout() const { return callback_timout_; } | |||
| private: | |||
| int32_t rows_per_buffer_{kCfgRowsPerBuffer}; | |||
| int32_t num_parallel_workers_{kCfgParallelWorkers}; | |||
| int32_t worker_connector_size_{kCfgWorkerConnectorSize}; | |||
| int32_t op_connector_size_{kCfgOpConnectorSize}; | |||
| uint32_t seed_{kCfgDefaultSeed}; | |||
| uint32_t monitor_sampling_interval_{kCfgMonitorSamplingInterval}; | |||
| uint32_t callback_timout_{kCfgCallbackTimeout}; | |||
| int32_t rows_per_buffer_; | |||
| int32_t num_parallel_workers_; | |||
| int32_t worker_connector_size_; | |||
| int32_t op_connector_size_; | |||
| uint32_t seed_; | |||
| uint32_t monitor_sampling_interval_; | |||
| uint32_t callback_timout_; | |||
| std::string cache_host_; | |||
| int32_t cache_port_; | |||
| // Private helper function that takes a nlohmann json format and populates the settings | |||
| // @param j - The json nlohmann json info | |||
| @@ -69,6 +69,8 @@ constexpr uint32_t kCfgOpConnectorSize = 16; | |||
| constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed; | |||
| constexpr uint32_t kCfgMonitorSamplingInterval = 10; | |||
| constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds | |||
| constexpr int32_t kCfgDefaultCachePort = 50052; | |||
| constexpr char kCfgDefaultCacheHost[] = "127.0.0.1"; | |||
| // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h) | |||
| constexpr uint8_t kCVInvalidType = 255; | |||
| @@ -25,21 +25,21 @@ | |||
| #include "minddata/dataset/engine/cache/cache_request.h" | |||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| #include "minddata/dataset/core/constants.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| const char CacheAdminArgHandler::kDefaultHost[] = "127.0.0.1"; | |||
| const char CacheAdminArgHandler::kServerBinary[] = "cache_server"; | |||
| const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp"; | |||
| CacheAdminArgHandler::CacheAdminArgHandler() | |||
| : port_(kDefaultPort), | |||
| : port_(kCfgDefaultCachePort), | |||
| session_id_(0), | |||
| num_workers_(kDefaultNumWorkers), | |||
| shm_mem_sz_(kDefaultSharedMemorySizeInGB), | |||
| log_level_(kDefaultLogLevel), | |||
| hostname_(kDefaultHost), | |||
| hostname_(kCfgDefaultCacheHost), | |||
| spill_dir_(kDefaultSpillDir), | |||
| command_id_(CommandId::kCmdUnknown) { | |||
| // Initialize the command mappings | |||
| @@ -376,6 +376,8 @@ Status CacheAdminArgHandler::StopServer() { | |||
| RETURN_IF_NOT_OK(comm.ServiceStart()); | |||
| auto rq = std::make_shared<ShutdownRequest>(); | |||
| RETURN_IF_NOT_OK(comm.HandleRequest(rq)); | |||
| // We will ignore the rc because if the shutdown is successful, the server will not reply back. | |||
| (void)rq->Wait(); | |||
| return Status::OK(); | |||
| } | |||
| @@ -29,11 +29,9 @@ namespace dataset { | |||
| class CacheAdminArgHandler { | |||
| public: | |||
| static constexpr int32_t kDefaultPort = 50052; | |||
| static constexpr int32_t kDefaultNumWorkers = 32; | |||
| static constexpr int32_t kDefaultSharedMemorySizeInGB = 4; | |||
| static constexpr int32_t kDefaultLogLevel = 1; | |||
| static const char kDefaultHost[]; | |||
| static const char kServerBinary[]; | |||
| static const char kDefaultSpillDir[]; | |||
| @@ -23,6 +23,35 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| CacheClient::Builder::Builder() | |||
| : session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_workers_(0), prefetch_size_(0) { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| hostname_ = cfg->cache_host(); | |||
| port_ = cfg->cache_port(); | |||
| num_workers_ = cfg->num_parallel_workers(); | |||
| prefetch_size_ = 20; // rows_per_buf is too small (1 by default). | |||
| } | |||
| Status CacheClient::Builder::Build(std::shared_ptr<CacheClient> *out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| *out = | |||
| std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_, prefetch_size_); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheClient::Builder::SanityCheck() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(port_ > 0, "port must be positive"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "illegal port number"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(hostname_ == "127.0.0.1", | |||
| "now cache client has to be on the same host with cache server"); | |||
| return Status::OK(); | |||
| } | |||
| // Constructor | |||
| CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, | |||
| @@ -44,13 +44,7 @@ class CacheClient { | |||
| /// \brief A builder to help creating a CacheClient object | |||
| class Builder { | |||
| public: | |||
| Builder() : session_id_(0), cache_mem_sz_(0), spill_(false), port_(0), num_workers_(0), prefetch_size_(0) { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| hostname_ = "127.0.0.1"; | |||
| port_ = 50052; | |||
| num_workers_ = cfg->num_parallel_workers(); | |||
| prefetch_size_ = 20; // rows_per_buf is too small (1 by default). | |||
| } | |||
| Builder(); | |||
| /// Setter function to set the session id | |||
| /// \param session_id | |||
| @@ -117,22 +111,9 @@ class CacheClient { | |||
| int32_t getNumWorkers() const { return num_workers_; } | |||
| int32_t getPrefetchSize() const { return prefetch_size_; } | |||
| Status SanityCheck() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty"); | |||
| return Status::OK(); | |||
| } | |||
| Status SanityCheck(); | |||
| Status Build(std::shared_ptr<CacheClient> *out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| *out = std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_, | |||
| prefetch_size_); | |||
| return Status::OK(); | |||
| } | |||
| Status Build(std::shared_ptr<CacheClient> *out); | |||
| private: | |||
| session_id_type session_id_; | |||
| @@ -20,24 +20,25 @@ from mindspore._c_dataengine import CacheClient | |||
| from ..core.validator_helpers import type_check, check_uint32, check_uint64 | |||
| class DatasetCache: | |||
| """ | |||
| A client to interface with tensor caching service | |||
| """ | |||
| def __init__(self, session_id=None, size=0, spilling=False, port=50052, prefetch_size=20): | |||
| def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, prefetch_size=20): | |||
| check_uint32(session_id, "session_id") | |||
| check_uint64(size, "size") | |||
| type_check(spilling, (bool,), "spilling") | |||
| check_uint32(port, "port") | |||
| check_uint32(prefetch_size, "prefetch size") | |||
| self.session_id = session_id | |||
| self.size = size | |||
| self.spilling = spilling | |||
| self.hostname = hostname | |||
| self.port = port | |||
| self.prefetch_size = prefetch_size | |||
| self.cache_client = CacheClient(session_id, size, spilling, port, prefetch_size) | |||
| self.cache_client = CacheClient(session_id, size, spilling, hostname, port, prefetch_size) | |||
| def GetStat(self): | |||
| return self.cache_client.GetStat() | |||
| @@ -51,6 +52,7 @@ class DatasetCache: | |||
| new_cache.session_id = copy.deepcopy(self.session_id, memodict) | |||
| new_cache.spilling = copy.deepcopy(self.spilling, memodict) | |||
| new_cache.size = copy.deepcopy(self.size, memodict) | |||
| new_cache.hostname = copy.deepcopy(self.hostname, memodict) | |||
| new_cache.port = copy.deepcopy(self.port, memodict) | |||
| new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict) | |||
| new_cache.cache_client = self.cache_client | |||