@@ -24,6 +24,11 @@ if (ENABLE_TDTQUE)
     add_definitions(-D ENABLE_TDTQUE)
     message(STATUS "TDT queue is enabled")
 endif ()
+if (MS_BUILD_GRPC)
+    set(ENABLE_CACHE true)
+    add_definitions(-D ENABLE_CACHE)
+    message(STATUS "Cache is enabled")
+endif()
 # code coverage
 # option(ENABLE_COVERAGE "Enable code coverage report" OFF)
@@ -47,10 +52,6 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib")
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
-include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
-set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
-ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
 ################## Include sub-modules ###############################
 add_subdirectory(util)
 add_subdirectory(core)
@@ -70,8 +71,6 @@ add_dependencies(engine-datasetops-source-sampler core)
 add_dependencies(engine-datasetops core)
 add_dependencies(engine-datasetops-mapop core)
 add_dependencies(engine-opt core)
-add_dependencies(engine-cache-client core)
-add_dependencies(engine-cache-server core)
 add_dependencies(engine-perf core)
 add_dependencies(engine-gnn core)
 add_dependencies(engine core)
@@ -85,7 +84,11 @@ endif()
 if (ENABLE_TDTQUE)
     add_dependencies(engine-tdt core)
 endif ()
+if (ENABLE_CACHE)
+    add_dependencies(engine-datasetops engine-cache-client)
+    add_dependencies(engine-cache-client core)
+    add_dependencies(engine-cache-server core)
+endif ()
 ################### Create _c_dataengine Library ######################
 set(submodules
     $<TARGET_OBJECTS:core>
@@ -105,7 +108,6 @@ set(submodules
     $<TARGET_OBJECTS:engine-datasetops>
     $<TARGET_OBJECTS:engine-opt>
     $<TARGET_OBJECTS:engine-cache-client>
-    $<TARGET_OBJECTS:engine-cache-server>
     $<TARGET_OBJECTS:engine>
     $<TARGET_OBJECTS:text>
     $<TARGET_OBJECTS:text-kernels>
@@ -123,8 +125,6 @@ else ()
     add_library(_c_dataengine SHARED ${submodules})
 endif ()
-add_dependencies(_c_dataengine generated_engine_files)
 if (ENABLE_PYTHON)
     set_target_properties(_c_dataengine PROPERTIES
         PREFIX "${PYTHON_MODULE_PREFIX}"
@@ -187,6 +187,6 @@ else()
     endif ()
 endif()
-if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
+if (MS_BUILD_GRPC)
     target_link_libraries(_c_dataengine PRIVATE mindspore::grpc++)
 endif()
@@ -22,7 +22,25 @@ namespace dataset {
 PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) {
                   (void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
-                    .def(py::init<uint32_t, uint64_t, bool>());
+                    .def(
+                      py::init([](session_id_type id, uint64_t mem_sz, bool spill, int32_t port, int32_t prefetch_sz) {
+                        std::shared_ptr<CacheClient> cc;
+                        CacheClient::Builder builder;
+                        builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPort(port).SetPrefetchSize(
+                          prefetch_sz);
+                        THROW_IF_ERROR(builder.Build(&cc));
+                        return cc;
+                      }))
+                    .def("GetStat", [](CacheClient &cc) {
+                      CacheServiceStat stat{};
+                      THROW_IF_ERROR(cc.GetStat(&stat));
+                      return stat;
+                    });
+                  (void)py::class_<CacheServiceStat>(*m, "CacheServiceStat")
+                    .def(py::init<>())
+                    .def_readwrite("avg_cache_sz", &CacheServiceStat::avg_cache_sz)
+                    .def_readwrite("num_mem_cached", &CacheServiceStat::num_mem_cached)
+                    .def_readwrite("num_disk_cached", &CacheServiceStat::num_disk_cached);
                 }));
 }  // namespace dataset
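Review note: the `py::init` lambda above is a thin shim over `CacheClient::Builder`, which this patch adds in cache_client.h (see the last hunk below). A minimal C++ sketch of the same construction path; the session id 1234, cache size 4, port 50052, and prefetch 20 are illustrative values, not defaults mandated by the binding:

```cpp
// Sketch only: mirrors what the py::init lambda does, via the Builder API added
// in this patch. All numeric values here are illustrative.
#include <memory>
#include "minddata/dataset/engine/cache/cache_client.h"

using mindspore::dataset::CacheClient;
using mindspore::dataset::Status;

Status MakeClient(std::shared_ptr<CacheClient> *out) {
  CacheClient::Builder builder;
  builder.SetSessionId(1234).SetCacheMemSz(4).SetSpill(true).SetPort(50052).SetPrefetchSize(20);
  return builder.Build(out);  // Build() runs SanityCheck() before constructing the client
}
```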
@@ -72,7 +72,8 @@ constexpr uint32_t kCfgMonitorSamplingInterval = 10;
 // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h)
 constexpr uint8_t kCVInvalidType = 255;
-using connection_id_type = int64_t;
+using connection_id_type = uint64_t;
+using session_id_type = uint32_t;
 using row_id_type = int64_t;
 }  // namespace dataset
 }  // namespace mindspore
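Review note: widening `connection_id_type` to `uint64_t` matters because the connection id packs the 32-bit session id into the upper half and the 32-bit dataset-tree CRC into the lower half; the removed line in `CreateCache()` further down did exactly this on the client (the packing now happens server-side from `cinfo_`). A sketch of the packing under these typedefs:

```cpp
// Sketch of the id packing implied by the typedef change above; grounded in the
// removed CreateCache() expression: (static_cast<uint64_t>(session_id_) << 32) | cache_crc_.
#include <cstdint>

using session_id_type = uint32_t;
using connection_id_type = uint64_t;

inline connection_id_type MakeConnectionId(session_id_type session_id, uint32_t crc) {
  // Session id in the upper 32 bits, dataset-tree CRC in the lower 32 bits.
  return (static_cast<connection_id_type>(session_id) << 32) | crc;
}
```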
@@ -20,10 +20,8 @@ if (ENABLE_PYTHON)
     target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
 endif()
+add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf engine-cache-client engine-datasetops-mapop)
 if (ENABLE_TDTQUE)
-    add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf
-            engine-cache-client engine-cache-server engine-datasetops-mapop)
-else ()
-    add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf
-            engine-cache-client engine-cache-server engine-datasetops-mapop)
+    add_dependencies(engine engine-tdt)
 endif ()
@@ -1,8 +1,47 @@
+include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
+set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
+ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
 add_library(engine-cache-client OBJECT
     cache_client.cc
+    cache_fbb.cc
     cache_request.cc)
-add_library(engine-cache-server OBJECT
-    cache_service.cc
-    cache_server.cc)
+if (ENABLE_CACHE)
+    ms_grpc_generate(CACHE_GRPC_SRCS CACHE_GRPC_HDRS cache_grpc.proto)
+    target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS} cache_grpc_client.cc)
+    add_library(engine-cache-server OBJECT
+        ${CACHE_GRPC_SRCS}
+        cache_grpc_server.cc
+        cache_arena.cc
+        cache_service.cc
+        cache_server.cc)
+    add_executable(cache_server cache_main.cc)
+    target_link_libraries(cache_server
+        engine-cache-server
+        $<TARGET_OBJECTS:utils>
+        mindspore
+        mindspore::glog
+        mindspore::protobuf
+        mindspore::grpc++
+        mindspore_gvar
+        ${PYTHON_LIBRARIES}
+        ${SECUREC_LIBRARY}
+        pthread)
+    add_executable(cache_admin cache_admin.cc cache_admin_arg.cc)
+    target_link_libraries(cache_admin _c_dataengine _c_mindrecord ${PYTHON_LIBRARIES} mindspore::glog)
+    add_dependencies(engine-cache-server generated_engine_files)
+else ()
+    ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PROTO_HDRS cache_grpc.proto)
+    target_sources(engine-cache-client PUBLIC ${CACHE_PROTO_SRCS})
+endif ()
+add_dependencies(engine-cache-client generated_engine_files)
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <unistd.h>
+#include <iostream>
+#include <sstream>
+#ifdef USE_GLOG
+#include <glog/logging.h>
+#endif
+#include "minddata/dataset/engine/cache/cache_admin_arg.h"
+
+namespace ds = mindspore::dataset;
+
+int main(int argc, char **argv) {
+  ds::Status rc;
+  ds::CacheAdminArgHandler args;
+  std::stringstream arg_stream;
+
+#ifdef USE_GLOG
+  FLAGS_log_dir = "/tmp";
+  google::InitGoogleLogging(argv[0]);
+#endif
+
+  std::string warningMsg;
+  warningMsg.reserve(512);
+  warningMsg += "WARNING:\n";
+  warningMsg += "cache_admin and the cache server that it controls are currently intended only for experimental";
+  warningMsg += " research purposes.\n";
+  warningMsg += "They are not ready for general availability and may not be stable. Use them at your own risk.\n";
+
+  // A warning message until the code is mature enough.
+  std::cerr << warningMsg << std::endl;
+
+  if (argc == 1) {
+    args.Help();
+    return 0;
+  }
+
+  // Ingest all the args into a string stream for parsing
+  for (int i = 1; i < argc; ++i) {
+    arg_stream << " " << std::string(argv[i]);
+  }
+
+  // Parse the args
+  rc = args.ParseArgStream(&arg_stream);
+  if (!rc.IsOk()) {
+    std::cerr << rc.ToString() << std::endl;
+    return 1;
+  }
+
+  // Execute the command
+  rc = args.RunCommand();
+  if (!rc.IsOk()) {
+    std::cerr << rc.ToString() << std::endl;
+    return 1;
+  }
+
+  return 0;
+}
@@ -0,0 +1,396 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/cache/cache_admin_arg.h"
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <sys/wait.h>
+#include <unistd.h>
+#include <cerrno>
+#include <cstring>
+#include <iostream>
+#include <string>
+#include <cstdlib>
+#include "minddata/dataset/engine/cache/cache_request.h"
+#include "minddata/dataset/engine/cache/cache_client.h"
+#include "minddata/dataset/util/path.h"
+
+namespace mindspore {
+namespace dataset {
+const char CacheAdminArgHandler::kDefaultHost[] = "127.0.0.1";
+const char CacheAdminArgHandler::kServerBinary[] = "cache_server";
+const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp";
+
+CacheAdminArgHandler::CacheAdminArgHandler()
+    : port_(kDefaultPort),
+      session_id_(0),
+      num_workers_(kDefaultNumWorkers),
+      shm_mem_sz_(kDefaultSharedMemorySizeInGB),
+      log_level_(kDefaultLogLevel),
+      hostname_(kDefaultHost),
+      spill_dir_(kDefaultSpillDir),
+      command_id_(CommandId::kCmdUnknown) {
+  // Initialize the command mappings
+  arg_map_["-h"] = ArgValue::kArgHost;
+  arg_map_["--hostname"] = ArgValue::kArgHost;
+  arg_map_["-p"] = ArgValue::kArgPort;
+  arg_map_["--port"] = ArgValue::kArgPort;
+  arg_map_["--start"] = ArgValue::kArgStart;
+  arg_map_["--stop"] = ArgValue::kArgStop;
+  arg_map_["--help"] = ArgValue::kArgHelp;
+  arg_map_["--generate_session"] = ArgValue::kArgGenerateSession;
+  arg_map_["-g"] = ArgValue::kArgGenerateSession;
+  arg_map_["--destroy_session"] = ArgValue::kArgDestroySession;
+  arg_map_["-d"] = ArgValue::kArgDestroySession;
+  arg_map_["--spilldir"] = ArgValue::kArgSpillDir;
+  arg_map_["-s"] = ArgValue::kArgSpillDir;
+  arg_map_["-w"] = ArgValue::kArgNumWorkers;
+  arg_map_["--workers"] = ArgValue::kArgNumWorkers;
+  arg_map_["-m"] = ArgValue::kArgSharedMemorySize;
+  arg_map_["--shared_memory_size"] = ArgValue::kArgSharedMemorySize;
+  arg_map_["-l"] = ArgValue::kArgLogLevel;
+  arg_map_["--minloglevel"] = ArgValue::kArgLogLevel;
+
+  // Initialize the argument tracker with false values
+  for (int16_t i = 0; i < static_cast<int16_t>(ArgValue::kArgNumArgs); ++i) {
+    ArgValue currAV = static_cast<ArgValue>(i);
+    used_args_[currAV] = false;
+  }
+}
+
+Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
+                                       CommandId command_id) {
+  // Detect if the user tried to provide this argument more than once
+  ArgValue selected_arg = arg_map_[option];
+  if (used_args_[selected_arg]) {
+    std::string err_msg = "The " + option + " argument was given more than once.";
+    return Status(StatusCode::kSyntaxError, err_msg);
+  }
+
+  // Flag that this arg is used now
+  used_args_[selected_arg] = true;
+
+  // Some options are just arguments; for example, "--port 50052" is not a command, just an argument.
+  // Other options are actual commands; for example, "--destroy_session 1234" executes the destroy session.
+  // If this option is also a command, make sure multiple commands have not been given before assigning it.
+  if (command_id != CommandId::kCmdUnknown) {
+    if (command_id_ != CommandId::kCmdUnknown) {
+      std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
+      return Status(StatusCode::kSyntaxError, err_msg);
+    } else {
+      command_id_ = command_id;
+    }
+  }
+
+  // Fetch the argument from the arg stream into a string
+  std::string value_as_string;
+  *arg_stream >> value_as_string;
+  if (value_as_string.empty()) {
+    std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
+    return Status(StatusCode::kSyntaxError, err_msg);
+  }
+
+  // Now attempt to convert the value from its string form into a number
+  try {
+    *out_arg = std::stoul(value_as_string);
+  } catch (const std::exception &e) {
+    std::string err_msg = "Invalid numeric value: " + value_as_string;
+    return Status(StatusCode::kSyntaxError, err_msg);
+  }
+  return Status::OK();
+}
+Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream,
+                                       CommandId command_id) {
+  // Detect if the user tried to provide this argument more than once
+  ArgValue selected_arg = arg_map_[option];
+  if (used_args_[selected_arg]) {
+    std::string err_msg = "The " + option + " argument was given more than once.";
+    return Status(StatusCode::kSyntaxError, err_msg);
+  }
+
+  // Flag that this arg is used now
+  used_args_[selected_arg] = true;
+
+  // Some options are just arguments; for example, "--hostname 127.0.0.1" is not a command, just an argument.
+  // Other options are actual commands; for example, "--start".
+  // If this option is also a command, make sure multiple commands have not been given before assigning it.
+  if (command_id != CommandId::kCmdUnknown) {
+    if (command_id_ != CommandId::kCmdUnknown) {
+      std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
+      return Status(StatusCode::kSyntaxError, err_msg);
+    } else {
+      command_id_ = command_id;
+    }
+  }
+
+  // If there is no argument to get, such as the --start command, then out_arg will be a nullptr.
+  if (out_arg != nullptr) {
+    // Fetch the argument from the arg stream into a string
+    *arg_stream >> *out_arg;
+    if (out_arg->empty()) {
+      std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
+      return Status(StatusCode::kSyntaxError, err_msg);
+    }
+  }
+  return Status::OK();
+}
+Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) {
+  std::string tok;
+  while (*arg_stream >> tok) {
+    switch (arg_map_[tok]) {
+      case ArgValue::kArgHost: {
+        RETURN_IF_NOT_OK(AssignArg(tok, &hostname_, arg_stream));
+        break;
+      }
+      case ArgValue::kArgPort: {
+        RETURN_IF_NOT_OK(AssignArg(tok, &port_, arg_stream));
+        break;
+      }
+      case ArgValue::kArgStart: {
+        RETURN_IF_NOT_OK(AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdStart));
+        break;
+      }
+      case ArgValue::kArgStop: {
+        RETURN_IF_NOT_OK(AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdStop));
+        break;
+      }
+      case ArgValue::kArgGenerateSession: {
+        RETURN_IF_NOT_OK(
+          AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdGenerateSession));
+        break;
+      }
+      case ArgValue::kArgHelp: {
+        command_id_ = CommandId::kCmdHelp;
+        break;
+      }
+      case ArgValue::kArgDestroySession: {
+        // session_id is an unsigned type. We may need to template the AssignArg function so that
+        // it can handle different flavours of integers instead of just int32_t.
+        int32_t session_int;
+        RETURN_IF_NOT_OK(AssignArg(tok, &session_int, arg_stream, CommandId::kCmdDestroySession));
+        session_id_ = session_int;
+        break;
+      }
+      case ArgValue::kArgNumWorkers: {
+        RETURN_IF_NOT_OK(AssignArg(tok, &num_workers_, arg_stream));
+        break;
+      }
+      case ArgValue::kArgSpillDir: {
+        RETURN_IF_NOT_OK(AssignArg(tok, &spill_dir_, arg_stream));
+        break;
+      }
+      case ArgValue::kArgSharedMemorySize: {
+        RETURN_IF_NOT_OK(AssignArg(tok, &shm_mem_sz_, arg_stream));
+        break;
+      }
+      case ArgValue::kArgLogLevel: {
+        RETURN_IF_NOT_OK(AssignArg(tok, &log_level_, arg_stream));
+        break;
+      }
+      default: {
+        // Save space-delimited trailing arguments
+        trailing_args_ += (" " + tok);
+        break;
+      }
+    }
+  }
+  RETURN_IF_NOT_OK(Validate());
+  return Status::OK();
+}
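Review note: the `kArgDestroySession` branch above has to round-trip the unsigned session id through `int32_t` because `AssignArg` only has an integer overload for that type. A hedged sketch of the templating idea its own comment raises; `ConvertNumericArg` is a hypothetical name, not part of this patch, and it only covers integral types whose range fits in `long long` (which includes `session_id_type`, i.e. `uint32_t`):

```cpp
// Hypothetical sketch only: a templated numeric conversion AssignArg could
// delegate to, avoiding the int32_t round-trip for session ids.
#include <limits>
#include <stdexcept>
#include <string>
#include <type_traits>

template <typename T>
bool ConvertNumericArg(const std::string &value, T *out) {
  static_assert(std::is_integral<T>::value, "integral types only");
  try {
    long long v = std::stoll(value);  // throws invalid_argument / out_of_range
    if (v < static_cast<long long>(std::numeric_limits<T>::min()) ||
        v > static_cast<long long>(std::numeric_limits<T>::max())) {
      return false;  // parses, but does not fit in T
    }
    *out = static_cast<T>(v);
    return true;
  } catch (const std::exception &) {
    return false;
  }
}
```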
+Status CacheAdminArgHandler::Validate() {
+  // This sanity check is delayed until now in case there may be valid use-cases of trailing args.
+  // Any unhandled arguments at this point are an error.
+  if (!trailing_args_.empty()) {
+    std::string err_msg = "Invalid arguments provided: " + trailing_args_;
+    return Status(StatusCode::kSyntaxError, err_msg);
+  }
+
+  // The user must pick at least one command; it is meaningless to give only a hostname or port with no
+  // command to run.
+  if (command_id_ == CommandId::kCmdUnknown) {
+    std::string err_msg = "No command provided";
+    return Status(StatusCode::kSyntaxError, err_msg);
+  }
+
+  // Additional checks here
+  if (num_workers_ < 1) return Status(StatusCode::kSyntaxError, "Number of workers must be a positive value.");
+  if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range [0..3].");
+  // port range check?
+  return Status::OK();
+}
+
+Status CacheAdminArgHandler::RunCommand() {
+  switch (command_id_) {
+    case CommandId::kCmdHelp: {
+      Help();
+      break;
+    }
+    case CommandId::kCmdStart: {
+      RETURN_IF_NOT_OK(StartServer());
+      break;
+    }
+    case CommandId::kCmdStop: {
+      RETURN_IF_NOT_OK(StopServer());
+      break;
+    }
+    case CommandId::kCmdGenerateSession: {
+      CacheClientGreeter comm(hostname_, port_, 1);
+      RETURN_IF_NOT_OK(comm.ServiceStart());
+      auto rq = std::make_shared<GenerateSessionIdRequest>();
+      RETURN_IF_NOT_OK(comm.HandleRequest(rq));
+      RETURN_IF_NOT_OK(rq->Wait());
+      std::cout << rq->GetSessionId() << std::endl;
+      break;
+    }
+    case CommandId::kCmdDestroySession: {
+      CacheClientGreeter comm(hostname_, port_, 1);
+      RETURN_IF_NOT_OK(comm.ServiceStart());
+      CacheClientInfo cinfo;
+      cinfo.set_session_id(session_id_);
+      auto rq = std::make_shared<DropSessionRequest>(cinfo);
+      RETURN_IF_NOT_OK(comm.HandleRequest(rq));
+      RETURN_IF_NOT_OK(rq->Wait());
+      std::cout << "Drop session successful" << std::endl;
+      break;
+    }
+    default: {
+      RETURN_STATUS_UNEXPECTED("Invalid cache admin command id.");
+      break;
+    }
+  }
+  return Status::OK();
+}
+Status CacheAdminArgHandler::StartServer() {
+  // There currently does not exist any "install path" or method to identify which path the installed binaries will
+  // exist in. As a temporary approach, we will assume that the server binary shall exist in the same path as the
+  // cache_admin binary (this process).
+  const std::string self_proc = "/proc/self/exe";
+  std::string canonical_path;
+  canonical_path.resize(400);  // PATH_MAX is large. This value should be big enough for our use.
+  // Some lower level OS library calls are needed here to determine the binary path.
+  // Fetch the path of the cache_admin binary (this process) into a C character array, then truncate off the
+  // binary name so that we are left with only the absolute path.
+  if (realpath(self_proc.data(), canonical_path.data()) == nullptr) {
+    std::string err_msg = "Failed to identify cache admin binary path: " + std::to_string(errno);
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  canonical_path.resize(strlen(canonical_path.data()));
+  auto last_separator = canonical_path.find_last_of('/');
+  CHECK_FAIL_RETURN_UNEXPECTED(last_separator != std::string::npos, "No / found");
+  // Truncate the binary name so we are left with the absolute path of the cache_admin binary.
+  canonical_path.resize(last_separator + 1);
+  std::string cache_server_binary = canonical_path + std::string(kServerBinary);
+
+  // Create a pipe before we fork. If all goes well, the child will run as a daemon in the background
+  // and never return until shutdown. If there is any error, the child will notify us through the pipe.
+  int fd[2];
+  if (pipe(fd) == -1) {
+    std::string err_msg = "Failed to create a pipe for communication " + std::to_string(errno);
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  // Fork the child process to become the daemon.
+  pid_t pid;
+  pid = fork();
+
+  // Failed to fork
+  if (pid < 0) {
+    std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno);
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  } else if (pid > 0) {
+    // As a parent, we close the write end. We only listen.
+    close(fd[1]);
+    dup2(fd[0], 0);
+    close(fd[0]);
+    wait(nullptr);
+    std::string msg;
+    const int32_t buf_sz = 1024;
+    msg.resize(buf_sz);
+    auto n = read(0, msg.data(), buf_sz);
+    if (n < 0) {
+      std::string err_msg = "Failed to read from pipe " + std::to_string(errno);
+      RETURN_STATUS_UNEXPECTED(err_msg);
+    }
+    msg.resize(n);
+    std::cout << msg << std::endl;
+    return Status::OK();
+  } else {
+    // Child here ...
+    // Close stdin; redirect stdout and stderr to the write end of the pipe.
+    close(fd[0]);
+    dup2(fd[1], 1);
+    dup2(fd[1], 2);
+    close(0);
+    close(fd[1]);
+    // exec the cache server binary in this process
+    std::string port_string = std::to_string(port_);
+    std::string workers_string = std::to_string(num_workers_);
+    std::string shared_memory_string = std::to_string(shm_mem_sz_);
+    std::string minloglevel_string = std::to_string(log_level_);
+    std::string daemonize_string = "true";
+    char *argv[8];
+    argv[0] = cache_server_binary.data();  // First arg is usually the binary name
+    argv[1] = spill_dir_.data();
+    argv[2] = workers_string.data();
+    argv[3] = port_string.data();
+    argv[4] = shared_memory_string.data();
+    argv[5] = minloglevel_string.data();
+    argv[6] = daemonize_string.data();
+    argv[7] = nullptr;
+    // Now exec the binary
+    execv(argv[0], argv);
+    // If the exec was successful, this line will never be reached due to the process image being replaced.
+    // ..unless exec failed.
+    std::string err_msg = "Failed to exec cache server: " + cache_server_binary;
+    std::cerr << err_msg << std::endl;
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+}
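Review note: `StartServer()` above combines three steps: resolving the server binary next to `/proc/self/exe`, a pipe whose write end becomes the child's stdout/stderr, and `execv` of the server. A stripped-down, self-contained sketch of just the fork-and-pipe handshake; the binary path and buffer size are illustrative:

```cpp
// Sketch only: the parent blocks on the read end of the pipe, so the child's
// first message (or its exec-failure output) reaches the admin tool before exit.
#include <sys/wait.h>
#include <unistd.h>
#include <iostream>

int LaunchAndEcho(const char *binary) {
  int fd[2];
  if (pipe(fd) == -1) return -1;
  pid_t pid = fork();
  if (pid < 0) return -1;
  if (pid == 0) {                 // child: becomes the launched program
    close(fd[0]);
    dup2(fd[1], STDOUT_FILENO);   // child's stdout/stderr flow into the pipe
    dup2(fd[1], STDERR_FILENO);
    close(fd[1]);
    execl(binary, binary, static_cast<char *>(nullptr));
    _exit(1);                     // only reached if exec failed
  }
  close(fd[1]);                   // parent: read whatever the child printed
  char buf[1024];
  ssize_t n = read(fd[0], buf, sizeof(buf) - 1);
  if (n > 0) {
    buf[n] = '\0';
    std::cout << buf << std::endl;
  }
  close(fd[0]);
  waitpid(pid, nullptr, 0);
  return 0;
}
```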
+Status CacheAdminArgHandler::StopServer() {
+  CacheClientGreeter comm(hostname_, port_, 1);
+  RETURN_IF_NOT_OK(comm.ServiceStart());
+  auto rq = std::make_shared<ShutdownRequest>();
+  RETURN_IF_NOT_OK(comm.HandleRequest(rq));
+  return Status::OK();
+}
+
+void CacheAdminArgHandler::Help() {
+  std::cerr << "Syntax:\n";
+  std::cerr << "   cache_admin [--start | --stop]\n";
+  std::cerr << "               [ [-h | --hostname] <hostname> ]\n";
+  std::cerr << "               [ [-p | --port] <port number> ]\n";
+  std::cerr << "               [ [-g | --generate_session] ]\n";
+  std::cerr << "               [ [-d | --destroy_session] <session id> ]\n";
+  std::cerr << "               [ [-w | --workers] <number of workers> ]\n";
+  std::cerr << "               [ [-s | --spilldir] <spilling directory> ]\n";
+  std::cerr << "               [ [-m | --shared_memory_size] <shared memory size> ]\n";
+  std::cerr << "               [ [-l | --minloglevel] <log level> ]\n";
+  std::cerr << "               [--help]" << std::endl;
+}
+}  // namespace dataset
+}  // namespace mindspore
@@ -0,0 +1,105 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
+
+#include <iostream>
+#include <map>
+#include <memory>
+#include <string>
+#include <sstream>
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/engine/cache/cache_client.h"
+
+namespace mindspore {
+namespace dataset {
+class CacheAdminArgHandler {
+ public:
+  static constexpr int32_t kDefaultPort = 50052;
+  static constexpr int32_t kDefaultNumWorkers = 32;
+  static constexpr int32_t kDefaultSharedMemorySizeInGB = 4;
+  static constexpr int32_t kDefaultLogLevel = 1;
+  static const char kDefaultHost[];
+  static const char kServerBinary[];
+  static const char kDefaultSpillDir[];
+
+  // These are the actual command types to execute
+  enum class CommandId : int16_t {
+    kCmdHelp = 0,
+    kCmdStart = 1,
+    kCmdStop = 2,
+    kCmdGenerateSession = 3,
+    kCmdDestroySession = 4,
+    kCmdUnknown = 32767
+  };
+
+  CacheAdminArgHandler();
+
+  ~CacheAdminArgHandler() = default;
+
+  Status ParseArgStream(std::stringstream *arg_stream);
+
+  Status RunCommand();
+
+  void Help();
+
+ private:
+  // These are the supported argument string integer mappings
+  enum class ArgValue : int16_t {
+    kArgUnknown = 0,  // Must be at position 0. Invalid map lookups in arg_map_ default to value 0.
+    kArgStart = 1,
+    kArgStop = 2,
+    kArgHost = 3,
+    kArgPort = 4,
+    kArgHelp = 5,
+    kArgGenerateSession = 6,
+    kArgDestroySession = 7,
+    kArgSpillDir = 8,
+    kArgNumWorkers = 9,
+    kArgSharedMemorySize = 10,
+    kArgLogLevel = 11,
+    kArgNumArgs = 12  // Must be the last position to provide a count
+  };
+
+  Status StartServer();
+
+  Status StopServer();
+
+  Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
+                   CommandId command_id = CommandId::kCmdUnknown);
+
+  Status AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream,
+                   CommandId command_id = CommandId::kCmdUnknown);
+
+  Status Validate();
+
+  CommandId command_id_;
+  int32_t port_;
+  int32_t num_workers_;
+  int32_t shm_mem_sz_;
+  int32_t log_level_;
+  session_id_type session_id_;
+  std::string hostname_;
+  std::string spill_dir_;
+  std::string trailing_args_;
+  std::map<std::string, ArgValue> arg_map_;
+  std::map<ArgValue, bool> used_args_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
@@ -0,0 +1,73 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/cache/cache_arena.h"
+#include "minddata/dataset/util/path.h"
+
+namespace mindspore {
+namespace dataset {
+CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB)
+    : Arena::Arena(val_in_GB * 1024), port_(port), shmid_(-1) {}
+
+CachedSharedMemoryArena::~CachedSharedMemoryArena() {
+#if CACHE_LOCAL_CLIENT
+  if (this->ptr_ != nullptr && this->ptr_ != reinterpret_cast<void *>(-1)) {
+    shmdt(this->ptr_);
+  }
+  this->ptr_ = nullptr;
+  if (shmid_ != -1) {
+    shmctl(shmid_, IPC_RMID, nullptr);
+    // Also remove the path we use to generate ftok.
+    Path p(PortToUnixSocketPath(port_));
+    (void)p.Remove();
+  }
+#endif
+}
+
+Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
+                                            size_t val_in_GB) {
+  RETURN_UNEXPECTED_IF_NULL(out);
+#if CACHE_LOCAL_CLIENT
+  auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB);
+  if (ba == nullptr) {
+    return Status(StatusCode::kOutOfMemory);
+  }
+  // Transfer the ownership of this pointer. Any future error in the processing
+  // will be handled by the destructor of *out.
+  (*out).reset(ba);
+  // Generate the shared memory key with ftok, derived from the port.
+  int err;
+  auto shm_key = PortToFtok(port, &err);
+  if (shm_key == (key_t)-1) {
+    std::string errMsg = "Ftok failed with errno " + std::to_string(err);
+    RETURN_STATUS_UNEXPECTED(errMsg);
+  }
+  auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
+  ba->shmid_ = shmget(shm_key, ba->size_in_bytes_, IPC_CREAT | IPC_EXCL | access_mode);
+  if (ba->shmid_ != -1) {  // shmget returns -1 on failure
+    ba->ptr_ = shmat(ba->shmid_, nullptr, 0);
+    if (ba->ptr_ == reinterpret_cast<void *>(-1)) {
+      RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
+    }
+  } else {
+    RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno));
+  }
+  uint64_t num_blks = ba->size_in_bytes_ / ARENA_BLK_SZ;
+  MS_LOG(DEBUG) << "Memory pool has " << num_blks << " blocks of size " << ARENA_BLK_SZ << ".";
+  ba->tr_.Insert(0, num_blks);
+#endif
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
+
+#include <memory>
+#include <string>
+#include "minddata/dataset/util/arena.h"
+#include "minddata/dataset/engine/cache/cache_common.h"
+
+namespace mindspore {
+namespace dataset {
+/// This is a derived class of Arena but resides in shared memory
+class CachedSharedMemoryArena : public Arena {
+ public:
+  ~CachedSharedMemoryArena() override;
+
+  /// \brief Create an Arena in shared memory
+  /// \param[out] out Pointer to a unique_ptr that will own the arena
+  /// \param port TCP/IP port, used to derive the shared memory key
+  /// \param val_in_GB Size of the shared memory in gigabytes
+  /// \return Status object
+  static Status CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port, size_t val_in_GB);
+
+  /// \brief This returns where we attach to the shared memory.
+  /// Some gRPC requests will ask for a shared memory block, and
+  /// we can't return the absolute address as this makes no sense
+  /// in the client. So instead we will return an address relative
+  /// to the base address of the shared memory where we attach to.
+  /// \return Base address of the shared memory.
+  const void *SharedMemoryBaseAddr() const { return this->ptr_; }
+
+ private:
+  int32_t port_;
+  int shmid_;
+
+  /// Private constructor. Not to be called directly.
+  CachedSharedMemoryArena(int32_t port, size_t val_in_GB);
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
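Review note: a minimal usage sketch for the arena above, assuming the usual minddata `Status` macros are visible through the included headers; the port (50052) and size (4 GB) are illustrative. On builds without `CACHE_LOCAL_CLIENT`, `CreateArena` does nothing beyond the null check:

```cpp
// Sketch only: create the shared memory arena and grab its base address, which
// the server uses to turn absolute addresses into client-relative offsets.
#include <memory>
#include "minddata/dataset/engine/cache/cache_arena.h"

using mindspore::dataset::CachedSharedMemoryArena;
using mindspore::dataset::Status;

Status MakeArena() {
  std::unique_ptr<CachedSharedMemoryArena> arena;
  RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&arena, 50052, 4));
  const void *base = arena->SharedMemoryBaseAddr();  // offsets sent to clients are relative to this
  (void)base;
  return Status::OK();
}
```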
@@ -17,29 +17,45 @@
 #include <iomanip>
 #include "minddata/dataset/engine/cache/cache_client.h"
 #include "minddata/dataset/engine/cache/cache_request.h"
-#include "minddata/dataset/engine/cache/cache_service.h"
+#include "minddata/dataset/engine/cache/cache_fbb.h"
 #include "minddata/dataset/util/bit.h"

 namespace mindspore {
 namespace dataset {

 // Constructor
-CacheClient::CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill)
-    : server_connection_id_(0), session_id_(session_id), cache_crc_(0), cache_mem_sz_(cache_mem_sz), spill_(spill) {}
+CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname,
+                         int32_t port, int32_t num_workers, int32_t prefetch_size)
+    : server_connection_id_(0),
+      cache_mem_sz_(cache_mem_sz),
+      spill_(spill),
+      local_bypass_(false),
+      hostname_(std::move(hostname)),
+      port_(port),
+      num_workers_(num_workers),
+      prefetch_size_(prefetch_size) {
+  cinfo_.set_session_id(session_id);
+  comm_ = std::make_shared<CacheClientGreeter>(hostname_, port_, num_workers_);
+}

 // print method for display cache details
 void CacheClient::Print(std::ostream &out) const {
-  out << "  Session id: " << session_id_ << "\n  Cache crc: " << cache_crc_
-      << "\n  Server cache id: " << server_connection_id_ << "\n  Cache mem size: " << cache_mem_sz_
-      << "\n  Spilling: " << std::boolalpha << spill_;
+  out << "  Session id: " << session_id() << "\n  Cache crc: " << cinfo_.crc()
+      << "\n  Server cache id: " << server_connection_id_ << "\n  Cache mem size: " << getCacheMemSz()
+      << "\n  Spilling: " << std::boolalpha << isSpill() << "\n  Hostname: " << getHostname()
+      << "\n  Port: " << getPort() << "\n  Number of rpc workers: " << getNumWorkers()
+      << "\n  Prefetch size: " << getPrefetchSize() << "\n  Local client support: " << std::boolalpha
+      << SupportLocalClient();
 }

 Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
-  CacheRowRequest rq(server_connection_id_, cookie());
-  RETURN_IF_NOT_OK(rq.SerializeCacheRowRequest(row));
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
+  auto rq = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
+  RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(this, row));
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
   if (row_id_from_server != nullptr) {
-    *row_id_from_server = rq.GetRowIdAfterCache();
+    *row_id_from_server = rq->GetRowIdAfterCache();
   }
   return Status::OK();
 }
@@ -47,29 +63,19 @@ Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_serv
 Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
   std::unique_ptr<DataBuffer> db_ptr = std::move(in);
   auto num_rows = db_ptr->NumRows();
-  std::vector<TensorRow> all_rows;
+  // We will send the requests async first on all rows and do a final wait.
   if (num_rows > 0) {
-    all_rows.reserve(num_rows);
-    // Break down the DataBuffer into TensorRow. We will send the requests async
-    // and then do a final wait.
-    MemGuard<CacheRowRequest> rq_arr;
-    RETURN_IF_NOT_OK(rq_arr.allocate(num_rows, server_connection_id_, cookie()));
-    CacheServer &cs = CacheServer::GetInstance();
+    auto arr = std::make_unique<std::shared_ptr<CacheRowRequest>[]>(num_rows);
    for (auto i = 0; i < num_rows; ++i) {
      TensorRow row;
-      auto rq = rq_arr[i];
      RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
-      RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(row));
-      RETURN_IF_NOT_OK(cs.PushRequest(rq));
-      // We can't let row go out of scope. Otherwise it will free all the tensor memory.
-      // So park it in the vector. When this function goes out of scope, its memory
-      // will be freed.
-      all_rows.push_back(std::move(row));
+      arr[i] = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
+      RETURN_IF_NOT_OK(arr[i]->SerializeCacheRowRequest(this, row));
+      RETURN_IF_NOT_OK(PushRequest(arr[i]));
    }
-    // Now we wait for the requests to be done.
+    // Now we wait for them to come back
    for (auto i = 0; i < num_rows; ++i) {
-      auto rq = rq_arr[i];
-      RETURN_IF_NOT_OK(rq->Wait());
+      RETURN_IF_NOT_OK(arr[i]->Wait());
    }
  }
  return Status::OK();
@@ -77,11 +83,21 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
 Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
   RETURN_UNEXPECTED_IF_NULL(out);
-  BatchFetchRequest rq(server_connection_id_, row_id);
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
-  RETURN_IF_NOT_OK(rq.RestoreRows(out));
-  return Status::OK();
+  auto rq = std::make_shared<BatchFetchRequest>(server_connection_id_, row_id, SupportLocalClient());
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
+  int64_t mem_addr = -1;  // initialize so the check below is well defined if RestoreRows bails out early
+  Status rc = rq->RestoreRows(out, comm_->SharedMemoryBaseAddr(), &mem_addr);
+  // Free the memory by sending a request back to the server.
+  if (mem_addr != -1) {
+    auto mfree_req = std::make_shared<FreeSharedBlockRequest>(server_connection_id_, mem_addr);
+    Status rc2 = PushRequest(mfree_req);
+    // But we won't wait for the result for the sake of performance.
+    if (rc.IsOk() && rc2.IsError()) {
+      rc = rc2;
+    }
+  }
+  return rc;
 }
 Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
@@ -108,40 +124,44 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
   // to create a cache and some other tree is trying to use the same cache.
   // That is allowed, however the crc better match!
   if (server_connection_id_) {
-    if (cache_crc_ != tree_crc) {
+    if (cinfo_.crc() != tree_crc) {
       RETURN_STATUS_UNEXPECTED("Attempt to re-use a cache for a different tree!");
     }
     // Check the state of the server. For non-mappable case where there is a build phase and a fetch phase, we should
     // skip the build phase.
     lck.Unlock();  // GetStat will grab the mutex again. So unlock it to prevent deadlock.
-    CacheClient::ServiceStat stat{};
+    CacheServiceStat stat{};
     RETURN_IF_NOT_OK(GetStat(&stat));
     if (stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase)) {
       return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase");
     }
   } else {
-    cache_crc_ = tree_crc;  // It's really a new cache we're creating so save our crc in the client
-    // Combine the session and crc. This will form our client cache identifier.
-    connection_id_type connection_identification = (static_cast<uint64_t>(session_id_) << 32) | cache_crc_;
+    cinfo_.set_crc(tree_crc);  // It's really a new cache we're creating so save our crc in the client
     // Now execute the cache create request using this identifier and other configs
-    BaseRequest::CreateCacheFlag createFlag = BaseRequest::CreateCacheFlag::kNone;
+    CreateCacheRequest::CreateCacheFlag createFlag = CreateCacheRequest::CreateCacheFlag::kNone;
     if (spill_) {
-      createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk;
+      createFlag |= CreateCacheRequest::CreateCacheFlag::kSpillToDisk;
     }
     if (generate_id) {
-      createFlag |= BaseRequest::CreateCacheFlag::kGenerateRowId;
+      createFlag |= CreateCacheRequest::CreateCacheFlag::kGenerateRowId;
     }
-    CreationCacheRequest rq(connection_identification, cache_mem_sz_, createFlag);
-    RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-    Status rc = rq.Wait();
+    // Start the comm layer to receive reply
+    RETURN_IF_NOT_OK(comm_->ServiceStart());
+    // Initiate connection
+    auto rq = std::make_shared<CreateCacheRequest>(cinfo_, cache_mem_sz_, createFlag);
+    RETURN_IF_NOT_OK(PushRequest(rq));
+    Status rc = rq->Wait();
     if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) {
-      server_connection_id_ = rq.GetServerConnectionId();
+      std::string cookie;
+      rq->ParseResult(&server_connection_id_, &cookie);
       if (rc.IsOk()) {
         // The 1st guy creating the cache will get a cookie back.
         // But this object may be shared among pipelines and we don't want
         // to overwrite it.
-        cookie_ = rq.cookie();
+        cookie_ = cookie;
       }
+      // Attach to shared memory for local client
+      RETURN_IF_NOT_OK(comm_->AttachToSharedMemory(port_, &local_bypass_));
     }
     // We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the
     // CacheOp to bypass the build phase.
@@ -152,57 +172,57 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
 Status CacheClient::PurgeCache() {
   UniqueLock lck(&mux_);
-  PurgeCacheRequest rq(server_connection_id_);
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  return rq.Wait();
+  auto rq = std::make_shared<PurgeCacheRequest>(server_connection_id_);
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
+  return Status::OK();
 }

 Status CacheClient::DestroyCache() {
   UniqueLock lck(&mux_);
-  DestroyCacheRequest rq(server_connection_id_);
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  return rq.Wait();
+  auto rq = std::make_shared<DestroyCacheRequest>(server_connection_id_);
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
+  return Status::OK();
 }

-Status CacheClient::GetStat(ServiceStat *stat) {
+Status CacheClient::GetStat(CacheServiceStat *stat) {
   SharedLock lck(&mux_);
   RETURN_UNEXPECTED_IF_NULL(stat);
-  GetStatRequest rq(server_connection_id_);
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
-  stat->num_disk_cached = rq.GetNumDiskCached();
-  stat->num_mem_cached = rq.GetNumMemCached();
-  stat->min_row_id = rq.GetMinRowId();
-  stat->max_row_id = rq.GetMaxRowId();
-  stat->cache_service_state = rq.GetState();
+  auto rq = std::make_shared<GetStatRequest>(server_connection_id_);
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
+  rq->GetStat(stat);
   return Status::OK();
 }

 Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
   SharedLock lck(&mux_);
-  CacheSchemaRequest rq(server_connection_id_);
-  RETURN_IF_NOT_OK(rq.SerializeCacheSchemaRequest(map));
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
+  auto rq = std::make_shared<CacheSchemaRequest>(server_connection_id_);
+  RETURN_IF_NOT_OK(rq->SerializeCacheSchemaRequest(map));
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
   return Status::OK();
 }

 Status CacheClient::FetchSchema(std::unordered_map<std::string, int32_t> *map) {
   SharedLock lck(&mux_);
   RETURN_UNEXPECTED_IF_NULL(map);
-  FetchSchemaRequest rq(server_connection_id_);
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
-  *map = rq.GetColumnMap();
+  auto rq = std::make_shared<FetchSchemaRequest>(server_connection_id_);
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
+  *map = rq->GetColumnMap();
   return Status::OK();
 }

 Status CacheClient::BuildPhaseDone() const {
   SharedLock lck(&mux_);
-  BuildPhaseDoneRequest rq(server_connection_id_, cookie());
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
+  auto rq = std::make_shared<BuildPhaseDoneRequest>(server_connection_id_, cookie());
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
   return Status::OK();
 }
+
+Status CacheClient::PushRequest(std::shared_ptr<BaseRequest> rq) const { return comm_->HandleRequest(std::move(rq)); }
 }  // namespace dataset
 }  // namespace mindspore
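Review note: every public method in this file now follows the same three-step lifecycle: build a `shared_ptr` request, `PushRequest()` it into the gRPC comm layer, then `Wait()` for the reply. A hedged sketch of that pattern factored into a single helper; it assumes `PushRequest` is accessible to the caller (the cache_client.h hunk below is truncated before its declaration), and `SendAndWait` is a hypothetical name, not part of this patch:

```cpp
// Sketch only: the request lifecycle used by PurgeCache, DestroyCache, GetStat,
// CacheSchema, FetchSchema and BuildPhaseDone above, in one place.
#include <memory>
#include <utility>
#include "minddata/dataset/engine/cache/cache_client.h"

template <typename RequestT, typename... Args>
mindspore::dataset::Status SendAndWait(const mindspore::dataset::CacheClient &cc, Args &&... args) {
  auto rq = std::make_shared<RequestT>(std::forward<Args>(args)...);
  RETURN_IF_NOT_OK(cc.PushRequest(rq));  // hand off to the comm layer
  return rq->Wait();                     // block until the server replies
}
```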
@@ -23,9 +23,13 @@
 #include <utility>
 #include <vector>
+#include "minddata/dataset/core/config_manager.h"
+#ifdef ENABLE_CACHE
+#include "minddata/dataset/engine/cache/cache_grpc_client.h"
+#else
+#include "minddata/dataset/engine/cache/stub/cache_grpc_client.h"
+#endif
 #include "minddata/dataset/engine/data_buffer.h"
-#include "minddata/dataset/engine/cache/cache_server.h"
-#include "minddata/dataset/engine/cache/de_tensor_generated.h"
 #include "minddata/dataset/util/lock.h"

 namespace mindspore {
@@ -35,18 +39,120 @@ namespace dataset {
/// rows, etc.
class CacheClient {
 public:
  friend class CacheMergeOp;

  /// \brief A builder to help creating a CacheClient object
  class Builder {
   public:
    Builder() : session_id_(0), cache_mem_sz_(0), spill_(false), port_(0), num_workers_(0), prefetch_size_(0) {
      std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
      hostname_ = "127.0.0.1";
      port_ = 50052;
      num_workers_ = cfg->num_parallel_workers();
      prefetch_size_ = 20;  // rows_per_buf is too small (1 by default).
    }

    /// Setter function to set the session id
    /// \param session_id
    /// \return Builder object itself.
    Builder &SetSessionId(session_id_type session_id) {
      session_id_ = session_id;
      return *this;
    }

    /// Setter function to set the cache memory size
    /// \param cache_mem_sz
    /// \return Builder object itself
    Builder &SetCacheMemSz(uint64_t cache_mem_sz) {
      cache_mem_sz_ = cache_mem_sz;
      return *this;
    }

    /// Setter function to set the spill attribute
    /// \param spill
    /// \return Builder object itself
    Builder &SetSpill(bool spill) {
      spill_ = spill;
      return *this;
    }

    /// Setter function to set rpc hostname
    /// \param host
    /// \return Builder object itself
    Builder &SetHostname(std::string host) {
      hostname_ = std::move(host);
      return *this;
    }

    /// Setter function to set tcpip port
    /// \param port
    /// \return Builder object itself.
    Builder &SetPort(int32_t port) {
      port_ = port;
      return *this;
    }

    /// Setter function to set number of async rpc workers
    /// \param num_workers
    /// \return Builder object itself
    Builder &SetNumWorkers(int32_t num_workers) {
      num_workers_ = num_workers;
      return *this;
    }

    /// Setter function to set prefetch amount for fetching rows from cache server
    /// \param prefetch_sz
    /// \return Builder object itself
    Builder &SetPrefetchSize(int32_t prefetch_sz) {
      prefetch_size_ = prefetch_sz;
      return *this;
    }

    /// Getter functions
    session_id_type getSessionId() const { return session_id_; }
    uint64_t getCacheMemSz() const { return cache_mem_sz_; }
    bool isSpill() const { return spill_; }
    const std::string &getHostname() const { return hostname_; }
    int32_t getPort() const { return port_; }
    int32_t getNumWorkers() const { return num_workers_; }
    int32_t getPrefetchSize() const { return prefetch_size_; }

    Status SanityCheck() {
      CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive");
      CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative (0 implies unlimited)");
      CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive");
      CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
      CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
      return Status::OK();
    }

    Status Build(std::shared_ptr<CacheClient> *out) {
      RETURN_UNEXPECTED_IF_NULL(out);
      RETURN_IF_NOT_OK(SanityCheck());
      *out = std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_,
                                           prefetch_size_);
      return Status::OK();
    }

   private:
    session_id_type session_id_;
    uint64_t cache_mem_sz_;
    bool spill_;
    std::string hostname_;
    int32_t port_;
    int32_t num_workers_;
    int32_t prefetch_size_;
  };
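
  // A minimal usage sketch of the Builder above (illustrative only, not part of this change;
  // the session id and settings are made-up values):
  //
  //   std::shared_ptr<CacheClient> cache_client;
  //   CacheClient::Builder builder;
  //   Status rc = builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true).Build(&cache_client);
  //   // rc is an error if SanityCheck() fails, e.g. a session id of 0 or an empty hostname.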
  /// \brief Constructor
  /// \param session_id A user assigned session id for the current pipeline
  /// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
  /// \param spill Spill to disk if out of memory
  CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill);
  CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, int32_t port,
              int32_t num_workers, int32_t prefetch_size);

  /// \brief Destructor
  ~CacheClient() = default;

  /// \brief Getter function for returning the current session id
  /// \return session id
  uint64_t session_id() const { return session_id_; }
  ~CacheClient() { (void)comm_->ServiceStop(); }

  /// \brief Send a TensorRow to the cache server
  /// \param[in] row
@@ -83,14 +189,7 @@ class CacheClient {
  /// \brief Get the statistics from a cache.
  /// \param[in/out] Pointer to a pre-allocated ServiceStat object
  /// \return Status object
  struct ServiceStat {
    int64_t num_mem_cached;
    int64_t num_disk_cached;
    row_id_type min_row_id;
    row_id_type max_row_id;
    int8_t cache_service_state;
  };
  Status GetStat(ServiceStat *);
  Status GetStat(CacheServiceStat *);

  /// \brief Cache the schema at the cache server
  /// \param map The unordered map of the schema
@@ -122,18 +221,45 @@
  /// \return Cookie
  std::string cookie() const { return cookie_; }

  /// \brief Send a request async to the server
  /// \param rq BaseRequest
  /// \return Status object
  Status PushRequest(std::shared_ptr<BaseRequest> rq) const;

  /// \brief If the remote server supports local bypass using shared memory
  /// \return boolean value
  bool SupportLocalClient() const { return local_bypass_; }

  /// \brief Return the base memory address if we attach to any shared memory.
  auto SharedMemoryBaseAddr() const { return comm_->SharedMemoryBaseAddr(); }

  /// Getter functions
  session_id_type session_id() const { return cinfo_.session_id(); }
  uint64_t getCacheMemSz() const { return cache_mem_sz_; }
  bool isSpill() const { return spill_; }
  const std::string &getHostname() const { return hostname_; }
  int32_t getPort() const { return port_; }
  int32_t getNumWorkers() const { return num_workers_; }
  int32_t getPrefetchSize() const { return prefetch_size_; }

 private:
  mutable RWLock mux_;
  uint64_t cache_mem_sz_;
  bool spill_;
  // The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow
  // sharing of the cache.
  uint32_t session_id_;
  uint32_t cache_crc_;
  CacheClientInfo cinfo_;
  // The server_connection_id_ is the actual id we use for operations after the cache is built
  connection_id_type server_connection_id_;
  // Some magic cookie returned from the cache server.
  std::string cookie_;
  // Comm layer
  bool local_bypass_;
  std::string hostname_;
  int32_t port_;
  int32_t num_workers_;
  int32_t prefetch_size_;
  mutable std::shared_ptr<CacheClientGreeter> comm_;
};
}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,90 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_

/// \note This header file contains common header files and some inlines used by
/// both client and server side codes. Do not put code that is not common here.
/// There are client and server specific header files.

// On platforms like Windows, we may support only tcp/ip clients
#if !defined(_WIN32) && !defined(_WIN64)
#define CACHE_LOCAL_CLIENT 1
#endif

#ifdef CACHE_LOCAL_CLIENT
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#else
typedef int key_t;
#endif
#ifdef ENABLE_CACHE
#include <grpcpp/grpcpp.h>
#endif
#include <string>
#ifdef ENABLE_CACHE
#include "proto/cache_grpc.grpc.pb.h"
#endif
#include "proto/cache_grpc.pb.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/de_tensor_generated.h"

namespace mindspore {
namespace dataset {
/// \brief CacheRow and BatchFetch requests will switch to the shared memory method (if supported
/// on the platform) when the number of bytes sent is greater than the following number.
/// For smaller amounts we won't get any benefit from the shared memory method because it
/// needs two rpc requests.
constexpr static int32_t kLocalByPassThreshold = 64 * 1024;
/// \brief A flag used by the BatchFetch request (client side) if it can support local bypass
constexpr static uint32_t kLocalClientSupport = 1;
/// \brief A flag used by the CacheRow request (client side) and BatchFetch (server side) reply to indicate if the
/// data is inline in the protobuf. This also implies kLocalClientSupport is true.
constexpr static uint32_t kDataIsInSharedMemory = 2;
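
// For illustration only: a CacheRow request from a local client that is large enough to take the
// shared memory path would carry both bits set in the request flag, i.e.
// flag = kLocalClientSupport | kDataIsInSharedMemory, while a small local request sets only
// kLocalClientSupport and inlines the data in the protobuf.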
/// \brief Convert a Status object into a protobuf
/// \param rc[in] Status object
/// \param reply[in/out] pointer to pre-allocated protobuf object
inline void Status2CacheReply(const Status &rc, CacheReply *reply) {
  reply->set_rc(static_cast<google::int32>(rc.get_code()));
  reply->set_msg(rc.ToString());
}

/// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number
/// \param port
/// \return unix socket url
inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); }

/// \brief Generate a shared memory key using the tcp/ip port.
/// \note It must be called after the cache server generates the unix socket or ftok will fail.
/// \note Caller must check the return value. -1 means ftok failed.
/// \param[in] port
/// \param[out] err. If not null and ftok fails, this will contain the value of errno
/// \return key
inline key_t PortToFtok(int port, int *err) {
  key_t shmkey = -1;
#ifdef CACHE_LOCAL_CLIENT
  const std::string unix_path = PortToUnixSocketPath(port);
  shmkey = ftok(unix_path.data(), 'a');
  if (err != nullptr && shmkey == (key_t)-1) {
    *err = errno;
  }
#endif
  return shmkey;
}
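
// Example (illustrative only) of the check the note above asks callers to make:
//
//   int err = 0;
//   key_t shm_key = PortToFtok(port, &err);
//   if (shm_key == (key_t)-1) {
//     // ftok failed (or this platform has no local client support); err holds errno.
//   }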
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
@@ -0,0 +1,151 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_fbb.h"

namespace mindspore {
namespace dataset {
/// A private function used by SerializeTensorRowHeader to serialize each column in a tensor row
/// \note Not to be called by outside world
/// \return Status object
Status SerializeOneTensorMeta(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb,
                              const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off) {
  RETURN_UNEXPECTED_IF_NULL(out_off);
  const Tensor *ts = ts_ptr.get();
  auto shape_off = fbb->CreateVector(ts->shape().AsVector());
  const auto ptr = ts->GetBuffer();
  if (ptr == nullptr) {
    RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
  }
  auto src = ts->type().value();
  TensorType dest;
#define CASE(t)                        \
  case DataType::t:                    \
    dest = TensorType::TensorType_##t; \
    break
  // Map the type to fill in the flat buffer.
  switch (src) {
    CASE(DE_BOOL);
    CASE(DE_INT8);
    CASE(DE_UINT8);
    CASE(DE_INT16);
    CASE(DE_UINT16);
    CASE(DE_INT32);
    CASE(DE_UINT32);
    CASE(DE_INT64);
    CASE(DE_UINT64);
    CASE(DE_FLOAT16);
    CASE(DE_FLOAT32);
    CASE(DE_FLOAT64);
    CASE(DE_STRING);
    default:
      MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
      RETURN_STATUS_UNEXPECTED("Unknown type");
  }
#undef CASE
  TensorMetaMsgBuilder ts_builder(*fbb);
  ts_builder.add_dims(shape_off);
  ts_builder.add_type(dest);
  auto ts_off = ts_builder.Finish();
  *out_off = ts_off;
  return Status::OK();
}

Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *out_fbb) {
  RETURN_UNEXPECTED_IF_NULL(out_fbb);
  std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
  try {
    // Construct inside the try block so an allocation failure is caught below.
    fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
    std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
    std::vector<int64_t> tensor_sz;
    v.reserve(row.size());
    tensor_sz.reserve(row.size());
    // We will go through each column in the row.
    for (const std::shared_ptr<Tensor> &ts_ptr : row) {
      flatbuffers::Offset<TensorMetaMsg> ts_off;
      RETURN_IF_NOT_OK(SerializeOneTensorMeta(fbb, ts_ptr, &ts_off));
      v.push_back(ts_off);
      tensor_sz.push_back(ts_ptr->SizeInBytes());
    }
    auto column_off = fbb->CreateVector(v);
    auto data_sz_off = fbb->CreateVector(tensor_sz);
    TensorRowHeaderMsgBuilder row_builder(*fbb);
    row_builder.add_column(column_off);
    row_builder.add_data_sz(data_sz_off);
    // Pass the row_id even if it may not be known.
    row_builder.add_row_id(row.getId());
    row_builder.add_size_of_this(-1);  // fill in later after we call Finish.
    auto out = row_builder.Finish();
    fbb->Finish(out);
    // Now go back to fill in size_of_this in the flat buffer.
    auto msg = GetMutableTensorRowHeaderMsg(fbb->GetBufferPointer());
    auto success = msg->mutate_size_of_this(fbb->GetSize());
    if (!success) {
      RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
    }
    (*out_fbb) = std::move(fbb);
    return Status::OK();
  } catch (const std::bad_alloc &e) {
    return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
  }
}

Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out) {
  RETURN_UNEXPECTED_IF_NULL(col_ts);
  auto shape_in = col_ts->dims();
  auto type_in = col_ts->type();
  std::vector<dsize_t> v;
  v.reserve(shape_in->size());
  v.assign(shape_in->begin(), shape_in->end());
  TensorShape shape(v);
  DataType::Type dest = DataType::DE_UNKNOWN;
#define CASE(t)               \
  case TensorType_##t:        \
    dest = DataType::Type::t; \
    break
  switch (type_in) {
    CASE(DE_BOOL);
    CASE(DE_INT8);
    CASE(DE_UINT8);
    CASE(DE_INT16);
    CASE(DE_UINT16);
    CASE(DE_INT32);
    CASE(DE_UINT32);
    CASE(DE_INT64);
    CASE(DE_UINT64);
    CASE(DE_FLOAT16);
    CASE(DE_FLOAT32);
    CASE(DE_FLOAT64);
    CASE(DE_STRING);
  }
#undef CASE
  DataType type(dest);
  std::shared_ptr<Tensor> ts;
  RETURN_IF_NOT_OK(
    Tensor::CreateFromMemory(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize(), &ts));
  // Next we restore the real data which can be embedded or stored separately.
  if (ts->SizeInBytes() != data.GetSize()) {
    MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
                  << "Dumping tensor\n"
                  << *ts << "\n";
    RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
  }
  *out = std::move(ts);
  return Status::OK();
}
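
// A minimal round-trip sketch (illustrative only; assumes a non-empty TensorRow `row` and that
// TensorRowHeaderMsg is the flat buffer root type, so the generated GetTensorRowHeaderMsg()
// accessor is available):
//
//   std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
//   RETURN_IF_NOT_OK(SerializeTensorRowHeader(row, &fbb));
//   auto msg = GetTensorRowHeaderMsg(fbb->GetBufferPointer());
//   ReadableSlice data(row[0]->GetBuffer(), row[0]->SizeInBytes());
//   std::shared_ptr<Tensor> restored;
//   RETURN_IF_NOT_OK(RestoreOneTensor(msg->column()->Get(0), data, &restored));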
}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,46 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_

/// This header contains some serialize and deserialize functions for tensor rows using
/// Google Flatbuffer

#include <memory>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/util/slice.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
/// \brief Function to serialize a TensorRow header, used by CacheRowRequest
/// \param row TensorRow
/// \param fbb [in/out] fbb that contains the serialized data
/// \return Status object
Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *fbb);

/// \brief A function used by BatchFetchRequest to deserialize a flat buffer back to a tensor row.
/// \param col_ts A serialized version of Tensor meta data
/// \param data Tensor data wrapped in a slice
/// \param out Tensor
/// \return Status object
Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out);
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_
@@ -0,0 +1,54 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
syntax = "proto3";
package mindspore.dataset;
option cc_enable_arenas = true;

// The session_id and crc work together to uniquely identify this particular cache and allow
// sharing of the cache.
message CacheClientInfo {
  uint32 session_id = 1;
  uint32 crc = 2;
}

message CacheRequest {
  // Type of rpc request
  int32 type = 1;
  // Extra optional flag used by individual requests if needed
  uint32 flag = 2;
  oneof connect_info {
    // The server_connection_id is the actual id we use for operations after the cache is built
    int64 connection_id = 3;
    // But for requests like CreateCache we have to use the session id and crc to connect to the server.
    CacheClientInfo connection_info = 4;
  }
  // Everything else is just a vector of buffers
  repeated bytes buf_data = 5;
}

message CacheReply {
  int32 rc = 1;
  string msg = 2;
  // Extra optional flag used by individual requests if needed
  uint32 flag = 3;
  // What the server sends back is a plain buffer
  bytes result = 4;
}

service CacheServerGreeter {
  rpc CacheServerRequest (CacheRequest) returns (CacheReply) {}
}
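
// For illustration: a CreateCache-style call identifies itself with connection_info (session id
// plus crc), while later calls reuse the connection_id handed back by the server. With the
// generated protobuf API that is roughly (a sketch with made-up values; the two setters are
// alternatives because the fields share a oneof):
//   CacheRequest rq;
//   rq.mutable_connection_info()->set_session_id(1);   // first contact
//   rq.set_connection_id(42);                          // subsequent requests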
@@ -0,0 +1,161 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_grpc_client.h"
#include <chrono>

namespace mindspore {
namespace dataset {
Status CacheClientRequestTag::MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
                                       std::unique_ptr<CacheClientRequestTag> &&tag) {
  // If there is anything extra we need to do before we send.
  RETURN_IF_NOT_OK(tag->base_rq_->Prepare());
  // One minute timeout
  auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
  tag->ctx_.set_deadline(deadline);
  tag->rpc_ = stub->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, cq);
  tag->rpc_->StartCall();
  // The last step is to release ownership and transfer it to the completion queue.
  // The memory will be released by WorkerEntry or by the destructor when we drain the queue.
  auto ccReqTag = tag.release();
  ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_,
                         ccReqTag);  // inject this object into the completion queue
  return Status::OK();
}

CacheClientGreeter::~CacheClientGreeter() {
  (void)ServiceStop();
  // Detach from shared memory if any
  if (shmat_addr_ != nullptr) {
    shmdt(shmat_addr_);
    shmat_addr_ = nullptr;
  }
}

CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers)
    : num_workers_(num_workers), shm_key_(-1), shm_id_(-1), shmat_addr_(nullptr) {
  grpc::ChannelArguments args;
  // We need to bump up the message size to unlimited. The default receiving
  // message limit is 4MB which is not big enough.
  args.SetMaxReceiveMessageSize(-1);
#if CACHE_LOCAL_CLIENT
  // Try to connect locally over the unix socket as the first preference.
  // We should resolve the hostname to an ip address here rather than doing a string compare.
  if (hostname == "127.0.0.1") {
    std::string target = "unix://" + PortToUnixSocketPath(port);
    channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
  } else {
#endif
    std::string target = hostname + ":" + std::to_string(port);
    channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
#if CACHE_LOCAL_CLIENT
  }
#endif
  stub_ = CacheServerGreeter::NewStub(channel_);
}

Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass) {
  *local_bypass = false;
#if CACHE_LOCAL_CLIENT
  int err;
  shm_key_ = PortToFtok(port, &err);
  if (shm_key_ == (key_t)-1) {
    std::string errMsg = "Ftok failed with errno " + std::to_string(err);
    RETURN_STATUS_UNEXPECTED(errMsg);
  }
  // Attach to the shared memory
  shm_id_ = shmget(shm_key_, 0, 0);
  if (shm_id_ == -1) {
    RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno));
  }
  shmat_addr_ = shmat(shm_id_, nullptr, 0);
  if (shmat_addr_ == reinterpret_cast<void *>(-1)) {
    RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
  }
  *local_bypass = true;
#endif
  return Status::OK();
}
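
// Illustrative bring-up sequence (a sketch only; it assumes ServiceStart() from the Service base
// class drives DoServiceStart() below, which is how the CacheClient is expected to use this):
//
//   auto comm = std::make_shared<CacheClientGreeter>(hostname, port, num_workers);
//   RETURN_IF_NOT_OK(comm->ServiceStart());
//   bool local_bypass = false;
//   RETURN_IF_NOT_OK(comm->AttachToSharedMemory(port, &local_bypass));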

Status CacheClientGreeter::DoServiceStart() {
  RETURN_IF_NOT_OK(vg_.ServiceStart());
  RETURN_IF_NOT_OK(DispatchWorkers(num_workers_));
  return Status::OK();
}

Status CacheClientGreeter::DoServiceStop() {
  // Shutdown the queue. We don't accept any more new incomers.
  cq_.Shutdown();
  // Shutdown the TaskGroup.
  vg_.interrupt_all();
  vg_.join_all(Task::WaitFlag::kNonBlocking);
  // Drain the queue
  bool success;
  void *tag;
  while (cq_.Next(&tag, &success)) {
    auto r = reinterpret_cast<CacheClientRequestTag *>(tag);
    delete r;
  }
  return Status::OK();
}

Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) {
  auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq));
  return tag->MakeCall(stub_.get(), &cq_, std::move(tag));
}

Status CacheClientGreeter::WorkerEntry() {
  TaskManager::FindMe()->Post();
  do {
    bool success;
    void *tag;
    auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
    // Set a timeout of one second. Check for interrupt if we need to do early exit.
    auto r = cq_.AsyncNext(&tag, &success, deadline);
    if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) {
      auto rq = reinterpret_cast<CacheClientRequestTag *>(tag);
      if (success) {
        auto &rc = rq->rc_;
        if (!rc.ok()) {
          auto error_code = rq->rc_.error_code();
          std::string errMsg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code);
          Status remote_rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
          Status2CacheReply(remote_rc, &rq->base_rq_->reply_);
        }
        // Notify the waiting thread.
        rq->Notify();
      }
      // We can now free the memory
      delete rq;
    } else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) {
      // If we are interrupted, exit. Otherwise wait again.
      RETURN_IF_INTERRUPTED();
    } else {
      // Queue is drained.
      break;
    }
  } while (true);
  return Status::OK();
}

Status CacheClientGreeter::DispatchWorkers(int32_t num_workers) {
  auto f = std::bind(&CacheClientGreeter::WorkerEntry, this);
  for (auto i = 0; i < num_workers; ++i) {
    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Async reply", f));
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,102 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_

#include <memory>
#include <string>
#include <utility>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/util/service.h"
#include "minddata/dataset/util/task_manager.h"

namespace mindspore {
namespace dataset {
/// \brief A client view of a gRPC request
/// Like the class CacheServerRequest, this is used as a tag to inject into the gRPC
/// completion queue. The thread that makes the rpc request will wait on a wait post
/// area for the reply to come back. Since this tag will be deleted from memory, we
/// need to work on a shared pointer to the BaseRequest such that its use count is at
/// least two. Otherwise either thread could be referencing stale memory.
/// \see CacheServerRequest
class CacheClientRequestTag {
 public:
  friend class CacheClientGreeter;
  explicit CacheClientRequestTag(std::shared_ptr<BaseRequest> rq) : base_rq_(std::move(rq)) {}
  ~CacheClientRequestTag() = default;

  /// \brief Make an RPC call
  /// \param stub from CacheClientGreeter
  /// \param cq from CacheClientGreeter
  /// \return Status object
  static Status MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
                         std::unique_ptr<CacheClientRequestTag> &&tag);

  /// \brief Notify the client that a result has come back from the server
  void Notify() { base_rq_->wp_.Set(); }

 private:
  std::shared_ptr<BaseRequest> base_rq_;
  grpc::Status rc_;
  grpc::ClientContext ctx_;
  std::unique_ptr<grpc::ClientAsyncResponseReader<CacheReply>> rpc_;
};

/// \brief A gRPC layer to convert a BaseRequest into protobuf and send it to the cache server using gRPC
/// \see BaseRequest
class CacheClientGreeter : public Service {
  friend class CacheClient;

 public:
  explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers);
  ~CacheClientGreeter();

  /// Override base Service class
  Status DoServiceStart() override;
  Status DoServiceStop() override;

  /// \brief Send the request to the server
  /// \return Status object
  Status HandleRequest(std::shared_ptr<BaseRequest> rq);

  /// \brief A handful of threads will be handling async replies from the server
  /// \return Status object
  Status WorkerEntry();

  /// \brief Kick off threads to receive replies from the server
  Status DispatchWorkers(int32_t num_workers);

  /// \brief Attach to shared memory for local client
  /// \note Called after we have established a connection.
  /// \return Status object.
  Status AttachToSharedMemory(int32_t port, bool *local_bypass);

  /// \brief This returns where we attach to the shared memory.
  /// \return Base address of the shared memory.
  const void *SharedMemoryBaseAddr() const { return shmat_addr_; }
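
  // Note: offsets handed out by the server (e.g. via an AllocateSharedBlockRequest) are relative
  // to this base; a caller computes a usable pointer with, for example:
  //   auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(SharedMemoryBaseAddr()) + offset);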

 private:
  std::shared_ptr<grpc::Channel> channel_;
  std::unique_ptr<CacheServerGreeter::Stub> stub_;
  grpc::CompletionQueue cq_;
  TaskGroup vg_;
  int32_t num_workers_;
  key_t shm_key_;
  int32_t shm_id_;
  void *shmat_addr_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
@@ -0,0 +1,203 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_grpc_server.h"
#include <limits>
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/util/path.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace dataset {
CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb)
    : port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb) {
  // Setup a path for the unix socket.
  unix_socket_ = PortToUnixSocketPath(port);
  // We can't generate the ftok key yet until the unix_socket_ is created
}

void CacheServerGreeterImpl::Shutdown() {
  if (server_) {
    auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
    server_->Shutdown(deadline);
  }
  // Always shutdown the completion queue after the server.
  if (cq_) {
    cq_->Shutdown();
    // We need to drain the queue. All the tags are coming from
    // the Services pool which will be shutdown as well, so we
    // ignore the tag.
    void *tag;
    bool success;
    while (cq_->Next(&tag, &success)) {
    }
  }
}

CacheServerGreeterImpl::~CacheServerGreeterImpl() { Shutdown(); }

Status CacheServerGreeterImpl::IpcResourceCleanup() {
#if CACHE_LOCAL_CLIENT
  int err;
  auto shm_key = PortToFtok(port_, &err);
  // We are expecting the unix path doesn't exist.
  if (shm_key == (key_t)-1) {
    return Status::OK();
  }
  // Attach to the shared memory
  auto shm_id = shmget(shm_key, 0, 0);
  if (shm_id == -1) {
    return Status::OK();
  }
  struct shmid_ds ds {};
  auto inx = shmctl(shm_id, IPC_STAT, &ds);
  if (inx == -1) {
    std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id);
    errMsg += "\nPlease remove it manually using the ipcrm -m command";
    RETURN_STATUS_UNEXPECTED(errMsg);
  }
  if (ds.shm_nattch == 0) {
    // Stale shared memory from last time.
    // Remove both the memory and the socket path
    inx = shmctl(shm_id, IPC_RMID, nullptr);
    if (inx == -1) {
      std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id);
      errMsg += ". Errno: " + std::to_string(errno);
      errMsg += "\nPlease remove it manually using the ipcrm -m command";
      RETURN_STATUS_UNEXPECTED(errMsg);
    }
    Path p(unix_socket_);
    (void)p.Remove();
  } else {
    // Server is already up.
    MS_LOG(ERROR) << "Cache server is already up and running";
    // We return a duplicate error. The main() will intercept
    // and output a proper message
    return Status(StatusCode::kDuplicateKey);
  }
#endif
  return Status::OK();
}

Status CacheServerGreeterImpl::Run() {
  // To listen on all interfaces, use 0.0.0.0
  // Use 127.0.0.1 if just locally on the same machine.
  std::string host("0.0.0.0");  // listen on all interfaces.
  std::string server_address = host + ":" + std::to_string(port_);
  grpc::ServerBuilder builder;
  // Default message size for gRPC is 4MB. Increase it to 2g-1.
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int32_t>::max());
  int port_tcpip = 0;
#if CACHE_LOCAL_CLIENT
  int port_local = 0;
  // Check if we need to do clean up on the shared memory if the server
  // came down unexpectedly, e.g. from a SEGV.
  RETURN_IF_NOT_OK(IpcResourceCleanup());
  // We also optimize for local clients on the same machine using a unix socket
  builder.AddListeningPort("unix://" + unix_socket_, grpc::InsecureServerCredentials(), &port_local);
#endif
  builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_tcpip);
  builder.RegisterService(&svc_);
  cq_ = builder.AddCompletionQueue();
  server_ = builder.BuildAndStart();
  if (server_) {
    MS_LOG(INFO) << "Server listening on " << server_address;
#if CACHE_LOCAL_CLIENT
    RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_));
    MS_LOG(INFO) << "Creation of local socket and shared memory successful";
#endif
  } else {
    std::string errMsg = "Failed to start server. ";
    if (port_tcpip != port_) {
      errMsg += "Unable to bind to tcpip port " + std::to_string(port_) + ".";
    }
#if CACHE_LOCAL_CLIENT
    if (port_local == 0) {
      errMsg += " Unable to create unix socket " + unix_socket_ + ".";
    }
#endif
    RETURN_STATUS_UNEXPECTED(errMsg);
  }
  return Status::OK();
}

Status CacheServerGreeterImpl::HandleRequest(int32_t worker_id) {
  bool success;
  void *tag;
  // We loop through the grpc queue. Each connection, if successful,
  // will come back with our own tag which is an instance of CacheServerRequest,
  // and we simply call its functor. But first we need to create these instances
  // and inject them into the grpc queue.
  CacheServerRequest *p;
  // Get a free tag from my free list.
  RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(worker_id, &p));
  RETURN_IF_NOT_OK((*p)(&svc_, cq_.get()));
  do {
    auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
    // Set a timeout of one second. Check for interrupt if we need to do early exit.
    auto r = cq_->AsyncNext(&tag, &success, deadline);
    if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) {
      if (success) {
        auto rq = static_cast<CacheServerRequest *>(tag);
        RETURN_IF_NOT_OK((*rq)(&svc_, cq_.get()));
      }
    } else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) {
      // If we are interrupted, exit. Otherwise wait again.
      RETURN_IF_INTERRUPTED();
    } else {
      // Queue is drained.
      break;
    }
  } while (true);
  return Status::OK();
}

Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grpc::ServerCompletionQueue *cq) {
  auto myQID = getQid();
  if (st_ == STATE::CREATE) {
    st_ = STATE::PROCESS;
    svc->RequestCacheServerRequest(&ctx_, &rq_, &responder_, cq, cq, this);
  } else if (st_ == STATE::PROCESS) {
    // Get a new tag and handle the next request before we serve the current request.
    // The tag will be recycled when its state is changed to FINISH.
    CacheServerRequest *next_rq;
    RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(myQID, &next_rq));
    RETURN_IF_NOT_OK((*next_rq)(svc, cq));
    // Now we continue with the current request.
    // The first thing we need to do is extract the type from the incoming request.
    // When this object was first created (i.e. STATE::CREATE), we set the type to UNKNOWN.
    type_ = static_cast<RequestType>(rq_.type());
    // Now we pass the address of this instance to CacheServer's main loop.
    MS_LOG(DEBUG) << "Handle request " << *this;
    auto &cs = CacheServer::GetInstance();
    RETURN_IF_NOT_OK(cs.PushRequest(myQID, this));
  } else if (st_ == STATE::FINISH) {
    MS_LOG(DEBUG) << *this << " Finished.";
    // Return back to the free list.
    RETURN_IF_NOT_OK(CacheServer::ReturnRequestTag(this));
  }
  return Status::OK();
}
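
// Tag life cycle driven by the functor above, summarized for reference:
//   CREATE  -> the tag arms itself on the completion queue via RequestCacheServerRequest().
//   PROCESS -> a request arrived; a fresh tag is armed first, then this one is pushed onto the
//              CacheServer worker queue to do the actual work.
//   FINISH  -> the reply has been sent; the tag is recycled back to the free list.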

void CacheServerRequest::Print(std::ostream &out) const {
  if (rq_.has_connection_info()) {
    out << "Session Id: " << rq_.connection_info().session_id() << " CRC: " << rq_.connection_info().crc();
  } else {
    out << "Connection Id: " << rq_.connection_id();
  }
  out << " ";
  BaseRequest::Print(out);
}
}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,103 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task_manager.h"

namespace mindspore {
namespace dataset {
/// \brief Server side view of BaseRequest. Incoming requests are in the form of protobuf objects
/// and this class is used to translate from protobuf to structures understood by the CacheService class.
/// \see CacheService
class CacheServerRequest : public BaseRequest {
 public:
  friend class CacheServer;
  enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
  explicit CacheServerRequest(int32_t queue_id)
      : BaseRequest::BaseRequest(BaseRequest::RequestType::kRequestUnknown),
        qid_(queue_id),
        st_(STATE::CREATE),
        responder_(&ctx_) {}
  ~CacheServerRequest() = default;

  /// \brief Functor. Used mainly by the CacheServerGreeterImpl class to tag each incoming request; this
  /// functor will translate each protobuf into some form understood by the CacheService class.
  /// \param svc Async service
  /// \param cq Completion queue
  /// \return Status object
  Status operator()(CacheServerGreeter::AsyncService *svc, grpc::ServerCompletionQueue *cq);

  /// \brief Override the base class Print method
  /// \param out
  void Print(std::ostream &out) const override;

  /// \brief Getter of the queue id
  /// \return The queue where the request should go to
  int32_t getQid() const { return qid_; }

 private:
  int32_t qid_;
  Status rc_;
  STATE st_;
  grpc::ServerContext ctx_;
  grpc::ServerAsyncResponseWriter<CacheReply> responder_;
};

/// \brief Implementation of CacheServerGreeter
/// \note It is an async server
/// \see cache_grpc.proto
class CacheServerGreeterImpl final {
  friend class CacheServer;

 public:
  explicit CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb);
  virtual ~CacheServerGreeterImpl();

  /// \brief Brings up the gRPC server
  /// \return Status object
  Status Run();

  /// \brief Entry function to handle cache server requests
  Status HandleRequest(int32_t worker_id);

  /// \brief Return the shared memory pool.
  /// \return The shared memory pool
  CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); }

  void Shutdown();

  Status IpcResourceCleanup();

 private:
  int32_t port_;
  size_t shm_pool_sz_in_gb_;
  std::string unix_socket_;
  CacheServerGreeter::AsyncService svc_;
  std::unique_ptr<grpc::ServerCompletionQueue> cq_;
  std::unique_ptr<grpc::Server> server_;
  std::unique_ptr<CachedSharedMemoryArena> shm_pool_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
@@ -0,0 +1,121 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_server.h"
#include <sys/types.h>
#include <unistd.h>
#ifdef USE_GLOG
#include <glog/logging.h>
#endif
#include <cstdlib>

namespace ds = mindspore::dataset;

int main(int argc, char **argv) {
  ds::Status rc;
  ds::CacheServer::Builder builder;

  // This executable is not to be called directly, and should be invoked by the cache_admin executable.
  if (argc != 7) {
    rc = ds::Status(ds::StatusCode::kSyntaxError);
    std::cerr << rc.ToString() << std::endl;
    return static_cast<int>(rc.get_code());
  }

  builder.SetRootDirectory(argv[1])
    .SetNumWorkers(strtol(argv[2], nullptr, 10))
    .SetPort(strtol(argv[3], nullptr, 10))
    .SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10));
#ifdef USE_GLOG
  FLAGS_minloglevel = strtol(argv[5], nullptr, 10);
#endif
  auto daemonize_string = argv[6];
  bool daemonize = strcmp(daemonize_string, "true") == 0 || strcmp(daemonize_string, "TRUE") == 0 ||
                   strcmp(daemonize_string, "t") == 0 || strcmp(daemonize_string, "T") == 0;

  // We always change directory to / on unix rather than using the directory where the cache_server
  // is called. This is a standard procedure for daemonizing a process on unix.
  if (chdir("/") == -1) {
    std::string errMsg = "Unable to change directory to /. Errno = " + std::to_string(errno);
    std::cerr << errMsg << std::endl;
    return -1;
  }

  // Simple check of the parameters before we move on.
  rc = builder.SanityCheck();
  if (rc.IsError()) {
    std::cerr << rc.ToString() << std::endl;
    return static_cast<int>(rc.get_code());
  }

#ifdef USE_GLOG
  FLAGS_log_dir = "/tmp";
  google::InitGoogleLogging(argv[0]);
#endif

  if (daemonize) {
    // Fork the child process to become the daemon.
    pid_t pid = fork();
    // Failed to fork
    if (pid < 0) {
      std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno);
      std::cerr << err_msg << std::endl;
      return errno;
    } else if (pid > 0) {
      // Parent
      std::cerr << "cache server daemon process has been created as process id: " << pid
                << "\nCheck log file for any start up error" << std::endl;
      signal(SIGCHLD, SIG_IGN);  // ignore sig child signal.
      return 0;
    } else {
      // The child process will continue from here if daemonized and the parent has already exited.
      // If we are running in the foreground, none of the code in the block below will be run.
      pid_t sid;
      umask(0);
      sid = setsid();
      if (sid < 0) {
        MS_LOG(ERROR) << "Failed to setsid(). Errno = " << std::to_string(errno);
        return errno;
      }
      close(0);
      close(1);
      close(2);
    }
  }

  // Dump the summary
  MS_LOG(INFO) << builder << std::endl;
  rc = builder.Build();
  if (rc.IsOk()) {
    ds::CacheServer &cs = ds::CacheServer::GetInstance();
    // Kick off the threads. Loop forever and never return unless error.
    rc = cs.Run();
    if (rc.get_code() == ds::StatusCode::kDuplicateKey) {
      std::string errMsg = "Server is already started";
      MS_LOG(ERROR) << errMsg;
      std::cerr << errMsg << std::endl;
      return 0;
    }
  }
  if (rc.IsError()) {
    MS_LOG(ERROR) << rc.ToString();
    std::cerr << rc.ToString() << std::endl;
    return static_cast<int>(rc.get_code());
  }
  return 0;
}
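
// Example invocation (illustrative only; in practice cache_admin spawns this binary, and the
// values below are made up): root directory, number of workers, port, shared memory size in GB,
// minimum log level, and whether to daemonize:
//
//   cache_server /tmp/mindspore_cache 32 50052 4 1 true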
@@ -14,154 +14,149 @@
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_request.h"
#include <cstdlib>
#include <thread>
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_fbb.h"

namespace mindspore {
namespace dataset {
Status CacheRowRequest::SerializeCacheRowRequest(const TensorRow &row) {
  buffers_.reserve(row.size() + 1);
  RETURN_IF_NOT_OK(SerializeTensorRowHeader(row));
  buffers_.push_back(fbb_->GetBufferPointer());
  for (const auto &ts : row) {
    buffers_.push_back(ts->GetBuffer());
  }
Status BaseRequest::Wait() {
  RETURN_IF_NOT_OK(wp_.Wait());
  Status remote_rc(static_cast<StatusCode>(reply_.rc()), reply_.msg());
  RETURN_IF_NOT_OK(remote_rc);
  // Any extra work to do before we return back to the client.
  RETURN_IF_NOT_OK(PostReply());
  return Status::OK();
}

Status CacheRowRequest::SerializeTensorRowHeader(const TensorRow &row) {
  try {
    fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
    std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
    std::vector<int64_t> tensor_sz;
    v.reserve(row.size());
    tensor_sz.reserve(row.size());
    // We will go through each column in the row.
    for (const std::shared_ptr<Tensor> &ts_ptr : row) {
      flatbuffers::Offset<TensorMetaMsg> ts_off;
      RETURN_IF_NOT_OK(SerializeOneTensorMeta(ts_ptr, &ts_off));
      v.push_back(ts_off);
      tensor_sz.push_back(ts_ptr->SizeInBytes());
Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row) {
  CHECK_FAIL_RETURN_UNEXPECTED(row.size() > 0, "Empty tensor row");
  CHECK_FAIL_RETURN_UNEXPECTED(cc->SupportLocalClient() == support_local_bypass_, "Local bypass mismatch");
  // Calculate how many bytes (not counting the cookie) we are sending to the server. We only
  // use shared memory (if supported) if we exceed a certain amount.
  std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
  RETURN_IF_NOT_OK(::mindspore::dataset::SerializeTensorRowHeader(row, &fbb));
  sz_ += fbb->GetSize();
  for (const auto &ts : row) {
    sz_ += ts->SizeInBytes();
  }
  bool sent_using_local_bypass = support_local_bypass_ ? (sz_ >= kLocalByPassThreshold) : false;
  uint32_t flag = 0;
  if (support_local_bypass_) {
    BitSet(&flag, kLocalClientSupport);
  }
  if (sent_using_local_bypass) {
    BitSet(&flag, kDataIsInSharedMemory);
  }
  rq_.set_flag(flag);
  if (sent_using_local_bypass) {
    MS_LOG(DEBUG) << "Requesting " << sz_ << " bytes of shared memory data";
    // Allocate shared memory from the server
    auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), sz_);
    RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
    RETURN_IF_NOT_OK(mem_rq->Wait());
    addr_ = mem_rq->GetAddr();
    // Now we need to add that to the base address of where we attach.
    auto base = cc->SharedMemoryBaseAddr();
    auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr_);
    // Now we copy the data onto shared memory.
    WritableSlice all(p, sz_);
    auto offset = fbb->GetSize();
    ReadableSlice header(fbb->GetBufferPointer(), fbb->GetSize());
    Status copy_rc;
    copy_rc = WritableSlice::Copy(&all, header);
    if (copy_rc.IsOk()) {
      for (const auto &ts : row) {
        WritableSlice row_data(all, offset, ts->SizeInBytes());
        ReadableSlice src(ts->GetBuffer(), ts->SizeInBytes());
        copy_rc = WritableSlice::Copy(&row_data, src);
        if (copy_rc.IsError()) {
          break;
        }
        offset += ts->SizeInBytes();
      }
      // Fill in where to find the data
      AddDataLocation();
    }
    auto column_off = fbb_->CreateVector(v);
    auto data_sz_off = fbb_->CreateVector(tensor_sz);
    TensorRowHeaderMsgBuilder row_builder(*fbb_);
    row_builder.add_column(column_off);
    row_builder.add_data_sz(data_sz_off);
    // Pass the row_id even if it may not be known.
    row_builder.add_row_id(row.getId());
    row_builder.add_size_of_this(-1);  // fill in later after we call Finish.
    auto out = row_builder.Finish();
    fbb_->Finish(out);
    // Now go back to fill in size_of_this in the flat buffer.
    auto msg = GetMutableTensorRowHeaderMsg(fbb_->GetBufferPointer());
    auto success = msg->mutate_size_of_this(fbb_->GetSize());
    if (!success) {
      RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
    if (copy_rc.IsError()) {
      // We need to return the memory back to the server
      auto mfree_req = GenerateFreeBlockRequest();
      Status rc = cc->PushRequest(mfree_req);
      // But we won't wait for the result for the sake of performance.
      if (rc.IsError()) {
        MS_LOG(ERROR) << "Push request for free memory failed.";
      }
      return copy_rc;
    }
    return Status::OK();
  } catch (const std::bad_alloc &e) {
    return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
  } else {
    // We have already filled the first buffer which is the cookie.
    sz_ += rq_.buf_data(0).size();
    rq_.add_buf_data(fbb->GetBufferPointer(), fbb->GetSize());
    for (const auto &ts : row) {
      rq_.add_buf_data(ts->GetBuffer(), ts->SizeInBytes());
    }
    MS_LOG(DEBUG) << "Sending " << sz_ << " bytes of tensor data in " << rq_.buf_data_size() << " segments";
  }
  return Status::OK();
}

Status CacheRowRequest::SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr,
                                               flatbuffers::Offset<TensorMetaMsg> *out_off) {
  RETURN_UNEXPECTED_IF_NULL(out_off);
  const Tensor *ts = ts_ptr.get();
  auto shape_off = fbb_->CreateVector(ts->shape().AsVector());
  const auto ptr = ts->GetBuffer();
  if (ptr == nullptr) {
    RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
  }
  auto src = ts->type().value();
  TensorType dest;
#define CASE(t)                        \
  case DataType::t:                    \
    dest = TensorType::TensorType_##t; \
    break
  // Map the type to fill in the flat buffer.
  switch (src) {
    CASE(DE_BOOL);
    CASE(DE_INT8);
| CASE(DE_UINT8); | |||||
| CASE(DE_INT16); | |||||
| CASE(DE_UINT16); | |||||
| CASE(DE_INT32); | |||||
| CASE(DE_UINT32); | |||||
| CASE(DE_INT64); | |||||
| CASE(DE_UINT64); | |||||
| CASE(DE_FLOAT16); | |||||
| CASE(DE_FLOAT32); | |||||
| CASE(DE_FLOAT64); | |||||
| CASE(DE_STRING); | |||||
| default: | |||||
| MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts; | |||||
| RETURN_STATUS_UNEXPECTED("Unknown type"); | |||||
| Status CacheRowRequest::PostReply() { | |||||
| if (!reply_.result().empty()) { | |||||
| row_id_from_server_ = strtoll(reply_.result().data(), nullptr, 10); | |||||
| } | } | ||||
| #undef CASE | |||||
| TensorMetaMsgBuilder ts_builder(*fbb_); | |||||
| ts_builder.add_dims(shape_off); | |||||
| ts_builder.add_type(dest); | |||||
| auto ts_off = ts_builder.Finish(); | |||||
| *out_off = ts_off; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, | |||||
| std::shared_ptr<Tensor> *out) { | |||||
| RETURN_UNEXPECTED_IF_NULL(col_ts); | |||||
| auto shape_in = col_ts->dims(); | |||||
| auto type_in = col_ts->type(); | |||||
| std::vector<dsize_t> v; | |||||
| v.reserve(shape_in->size()); | |||||
| v.assign(shape_in->begin(), shape_in->end()); | |||||
| TensorShape shape(v); | |||||
| DataType::Type dest = DataType::DE_UNKNOWN; | |||||
| #define CASE(t) \ | |||||
| case TensorType_##t: \ | |||||
| dest = DataType::Type::t; \ | |||||
| break | |||||
| switch (type_in) { | |||||
| CASE(DE_BOOL); | |||||
| CASE(DE_INT8); | |||||
| CASE(DE_UINT8); | |||||
| CASE(DE_INT16); | |||||
| CASE(DE_UINT16); | |||||
| CASE(DE_INT32); | |||||
| CASE(DE_UINT32); | |||||
| CASE(DE_INT64); | |||||
| CASE(DE_UINT64); | |||||
| CASE(DE_FLOAT16); | |||||
| CASE(DE_FLOAT32); | |||||
| CASE(DE_FLOAT64); | |||||
| CASE(DE_STRING); | |||||
| Status CacheRowRequest::Prepare() { | |||||
| if (BitTest(rq_.flag(), kDataIsInSharedMemory)) { | |||||
| // First one is cookie, followed by address and then size. | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() == 3, "Incomplete rpc data"); | |||||
| } else { | |||||
| // First one is cookie. 2nd one is the google flat buffers followed by a number of buffers. | |||||
| // But we are not going to decode them to verify. | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() >= 3, "Incomplete rpc data"); | |||||
| } | } | ||||
| #undef CASE | |||||
| DataType type(dest); | |||||
| std::shared_ptr<Tensor> ts; | |||||
| RETURN_IF_NOT_OK( | |||||
| Tensor::CreateFromMemory(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize(), &ts)); | |||||
| // Next we restore the real data which can be embedded or stored separately. | |||||
| if (ts->SizeInBytes() != data.GetSize()) { | |||||
| MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n" | |||||
| << "Dumping tensor\n" | |||||
| << *ts << "\n"; | |||||
| RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); | |||||
| } | |||||
| *out = std::move(ts); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status BatchFetchRequest::RestoreRows(TensorTable *out) { | |||||
| BatchFetchRequest::BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, | |||||
| bool local_bypass) | |||||
| : BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(local_bypass), row_id_(row_id) { | |||||
| rq_.set_connection_id(connection_id); | |||||
| rq_.set_flag(support_local_bypass_ ? kLocalClientSupport : 0); | |||||
| // Convert the row id into a flatbuffer | |||||
| flatbuffers::FlatBufferBuilder fbb; | |||||
| auto off_t = fbb.CreateVector(row_id); | |||||
| TensorRowIdsBuilder bld(fbb); | |||||
| bld.add_row_id(off_t); | |||||
| auto off = bld.Finish(); | |||||
| fbb.Finish(off); | |||||
| rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize()); | |||||
| } | |||||
| Status BatchFetchRequest::RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | RETURN_UNEXPECTED_IF_NULL(out); | ||||
| auto num_elements = row_id_.size(); | auto num_elements = row_id_.size(); | ||||
| auto *offset_array = reinterpret_cast<const int64_t *>(mem_.GetPointer()); | |||||
| const char *ptr = nullptr; | |||||
| int64_t sz = 0; | |||||
| // Tap into the reply flag to see where we can find the data. Server may decide the amount is | |||||
| // so small that it doesn't use shared memory method. | |||||
| auto flag = reply_.flag(); | |||||
| bool dataOnSharedMemory = support_local_bypass_ ? (BitTest(flag, kDataIsInSharedMemory)) : false; | |||||
| if (dataOnSharedMemory) { | |||||
| auto addr = strtoll(reply_.result().data(), nullptr, 10); | |||||
| ptr = reinterpret_cast<const char *>(reinterpret_cast<int64_t>(baseAddr) + addr); | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| *out_addr = addr; | |||||
| } else { | |||||
| ptr = reply_.result().data(); | |||||
| *out_addr = -1; | |||||
| } | |||||
| auto *offset_array = reinterpret_cast<const int64_t *>(ptr); | |||||
| sz = offset_array[num_elements]; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(support_local_bypass_ || sz == reply_.result().length(), "Length mismatch"); | |||||
| TensorTable tbl; | TensorTable tbl; | ||||
| tbl.reserve(num_elements); | tbl.reserve(num_elements); | ||||
| ReadableSlice all(mem_.GetPointer(), mem_.GetSizeInBytes()); | |||||
| ReadableSlice all(ptr, sz); | |||||
| for (auto i = 0; i < num_elements; ++i) { | for (auto i = 0; i < num_elements; ++i) { | ||||
| auto len = offset_array[i + 1] - offset_array[i]; | auto len = offset_array[i + 1] - offset_array[i]; | ||||
| TensorRow row; | TensorRow row; | ||||
| @@ -178,10 +173,12 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out) { | |||||
| auto col_ts = msg->column()->Get(k); | auto col_ts = msg->column()->Get(k); | ||||
| std::shared_ptr<Tensor> ts; | std::shared_ptr<Tensor> ts; | ||||
| ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k)); | ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k)); | ||||
| RETURN_IF_NOT_OK(RestoreOneTensor(col_ts, data, &ts)); | |||||
| RETURN_IF_NOT_OK(mindspore::dataset::RestoreOneTensor(col_ts, data, &ts)); | |||||
| row.push_back(ts); | row.push_back(ts); | ||||
| ts_offset += data.GetSize(); | ts_offset += data.GetSize(); | ||||
| } | } | ||||
| } else { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(len == 0, "Data corruption detected."); | |||||
| } | } | ||||
| tbl.push_back(std::move(row)); | tbl.push_back(std::move(row)); | ||||
| } | } | ||||
| @@ -189,36 +186,69 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| CreateCacheRequest::CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz, | |||||
| CreateCacheRequest::CreateCacheFlag flag) | |||||
| : BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag) { | |||||
| // Type has been set already in the base constructor. So we need to fill in the connection info. | |||||
| // On successful return, we will get the connection id | |||||
| rq_.mutable_connection_info()->operator=(cinfo); | |||||
| } | |||||
| Status CreateCacheRequest::Prepare() { | |||||
| try { | |||||
| flatbuffers::FlatBufferBuilder fbb; | |||||
| CreateCacheRequestMsgBuilder bld(fbb); | |||||
| bld.add_cache_mem_sz(cache_mem_sz_); | |||||
| bld.add_flag(static_cast<uint32_t>(flag_)); | |||||
| auto off = bld.Finish(); | |||||
| fbb.Finish(off); | |||||
| rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize()); | |||||
| return Status::OK(); | |||||
| } catch (const std::bad_alloc &e) { | |||||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||||
| } | |||||
| } | |||||
| Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) { | Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) { | ||||
| try { | try { | ||||
| fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>(); | |||||
| flatbuffers::FlatBufferBuilder fbb; | |||||
| std::vector<flatbuffers::Offset<ColumnNameMsg>> v; | std::vector<flatbuffers::Offset<ColumnNameMsg>> v; | ||||
| v.reserve(map.size()); | v.reserve(map.size()); | ||||
| for (auto &column : map) { | for (auto &column : map) { | ||||
| auto c = CreateColumnNameMsg(*fbb_, fbb_->CreateString(column.first), column.second); | |||||
| auto c = CreateColumnNameMsg(fbb, fbb.CreateString(column.first), column.second); | |||||
| v.push_back(c); | v.push_back(c); | ||||
| } | } | ||||
| auto v_off = fbb_->CreateVector(v); | |||||
| auto final_off = CreateSchemaMsg(*fbb_, v_off); | |||||
| fbb_->Finish(final_off); | |||||
| buf_ = fbb_->GetBufferPointer(); | |||||
| len_of_buf_ = fbb_->GetSize(); | |||||
| auto v_off = fbb.CreateVector(v); | |||||
| auto final_off = CreateSchemaMsg(fbb, v_off); | |||||
| fbb.Finish(final_off); | |||||
| rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize()); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } catch (const std::bad_alloc &e) { | } catch (const std::bad_alloc &e) { | ||||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | ||||
| } | } | ||||
| } | } | ||||
| std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() { | |||||
| if (column_name_id_map_.empty()) { | |||||
| auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(mem_.GetPointer()); | |||||
| auto v = map_msg->column(); | |||||
| for (auto i = 0; i < v->size(); ++i) { | |||||
| auto col = map_msg->column()->Get(i); | |||||
| column_name_id_map_.emplace(col->name()->str(), col->id()); | |||||
| } | |||||
| Status FetchSchemaRequest::PostReply() { | |||||
| auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(reply_.result().data()); | |||||
| auto v = map_msg->column(); | |||||
| for (auto i = 0; i < v->size(); ++i) { | |||||
| auto col = map_msg->column()->Get(i); | |||||
| column_name_id_map_.emplace(col->name()->str(), col->id()); | |||||
| } | } | ||||
| return column_name_id_map_; | |||||
| return Status::OK(); | |||||
| } | |||||
| std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() { return column_name_id_map_; } | |||||
| Status GetStatRequest::PostReply() { | |||||
| auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(reply_.result().data()); | |||||
| stat_.num_disk_cached = msg->num_disk_cached(); | |||||
| stat_.num_mem_cached = msg->num_mem_cached(); | |||||
| stat_.avg_cache_sz = msg->avg_cache_sz(); | |||||
| stat_.max_row_id = msg->max_row_id(); | |||||
| stat_.min_row_id = msg->min_row_id(); | |||||
| stat_.cache_service_state = msg->state(); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -18,11 +18,16 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <memory> | #include <memory> | ||||
| #include <iostream> | |||||
| #include <string> | #include <string> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #ifdef ENABLE_CACHE | |||||
| #include "proto/cache_grpc.grpc.pb.h" | |||||
| #endif | |||||
| #include "proto/cache_grpc.pb.h" | |||||
| #include "minddata/dataset/core/tensor_row.h" | #include "minddata/dataset/core/tensor_row.h" | ||||
| #include "minddata/dataset/engine/cache/de_tensor_generated.h" | #include "minddata/dataset/engine/cache/de_tensor_generated.h" | ||||
| #include "minddata/dataset/util/slice.h" | #include "minddata/dataset/util/slice.h" | ||||
| @@ -30,6 +35,17 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class CacheClient; | |||||
| /// \brief Statistic structure for GetStat request | |||||
| struct CacheServiceStat { | |||||
| int64_t num_mem_cached; | |||||
| int64_t num_disk_cached; | |||||
| int64_t avg_cache_sz; | |||||
| row_id_type min_row_id; | |||||
| row_id_type max_row_id; | |||||
| int8_t cache_service_state; | |||||
| }; | |||||
| /// \brief CacheClient communicates with CacheServer using Requests. | /// \brief CacheClient communicates with CacheServer using Requests. | ||||
| class BaseRequest { | class BaseRequest { | ||||
| public: | public: | ||||
| @@ -44,195 +60,301 @@ class BaseRequest { | |||||
| kCacheSchema = 6, | kCacheSchema = 6, | ||||
| kFetchSchema = 7, | kFetchSchema = 7, | ||||
| kBuildPhaseDone = 8, | kBuildPhaseDone = 8, | ||||
| kDropSession = 9, | |||||
| kGenerateSessionId = 10, | |||||
| kAllocateSharedBlock = 11, | |||||
| kFreeSharedBlock = 12, | |||||
| kStopService = 13, | |||||
| // Add new request before it. | // Add new request before it. | ||||
| kRequestUnknown = 32767 | kRequestUnknown = 32767 | ||||
| }; | }; | ||||
| // For kCreateCache | |||||
| enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L }; | |||||
| friend class CacheServer; | friend class CacheServer; | ||||
| friend class CacheServerRequest; | |||||
| friend class CacheClientGreeter; | |||||
| friend class CacheClientRequestTag; | |||||
| /// \brief Base class of a cache server request | /// \brief Base class of a cache server request | ||||
| /// \param connection_id A combination of session id and crc that uniquely identifies a connection. | |||||
| /// \param type Type of the request | /// \param type Type of the request | ||||
| explicit BaseRequest(connection_id_type connection_id, RequestType type) | |||||
| : type_(type), connection_id_(connection_id) {} | |||||
| explicit BaseRequest(RequestType type) : type_(type) { rq_.set_type(static_cast<google::int32>(type_)); } | |||||
| virtual ~BaseRequest() = default; | virtual ~BaseRequest() = default; | ||||
| /// \brief Wait for the completion of a request | |||||
| /// \return Status returned from the cache server | |||||
| Status Wait() { | |||||
| RETURN_IF_NOT_OK(wp_.Wait()); | |||||
| return rc_; | |||||
| /// \brief A print method for debugging | |||||
| /// \param out The output stream to write output to | |||||
| virtual void Print(std::ostream &out) const { out << "Request type: " << static_cast<int16_t>(type_); } | |||||
| /// \brief << Stream output operator overload | |||||
| /// \param out reference to the output stream | |||||
| /// \param rq reference to the BaseRequest | |||||
| /// \return the output stream | |||||
| friend std::ostream &operator<<(std::ostream &out, const BaseRequest &rq) { | |||||
| rq.Print(out); | |||||
| return out; | |||||
| } | } | ||||
| /// \brief Getter function of the current connection id | |||||
| /// \return Connection id | |||||
| connection_id_type GetServerConnectionId() const { return connection_id_; } | |||||
| /// \brief Derived class can implement extra work to be done before the request is sent to the server | |||||
| virtual Status Prepare() { return Status::OK(); } | |||||
| /// \brief Derived class can implement extra work to be done after the server sends the request | |||||
| virtual Status PostReply() { return Status::OK(); } | |||||
| /// \brief A method for the client to wait for the availability of the result back from the server. | |||||
| /// \return Status object | |||||
| Status Wait(); | |||||
| protected: | |||||
| CacheRequest rq_; // This is what we send to the server | |||||
| CacheReply reply_; // This is what the server send back | |||||
| private: | private: | ||||
| RequestType type_; | RequestType type_; | ||||
| connection_id_type connection_id_; | |||||
| Status rc_; | |||||
| WaitPost wp_; | |||||
| WaitPost wp_; // A sync area used by the client side. | |||||
| }; | }; | ||||
| class FreeSharedBlockRequest : public BaseRequest { | |||||
| public: | |||||
| friend class CacheServer; | |||||
| explicit FreeSharedBlockRequest(connection_id_type connection_id, int64_t addr) | |||||
| : BaseRequest(RequestType::kFreeSharedBlock) { | |||||
| rq_.set_connection_id(connection_id); | |||||
| rq_.add_buf_data(std::to_string(addr)); | |||||
| } | |||||
| ~FreeSharedBlockRequest() = default; | |||||
| }; | |||||
| /// \brief Request to cache a single TensorRow | /// \brief Request to cache a single TensorRow | ||||
| class CacheRowRequest : public BaseRequest { | class CacheRowRequest : public BaseRequest { | ||||
| public: | public: | ||||
| friend class CacheServer; | friend class CacheServer; | ||||
| explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie) | |||||
| : BaseRequest(connection_id, RequestType::kCacheRow), row_id_from_server_(-1), cookie_(cookie) {} | |||||
| friend class CacheClient; | |||||
| explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie, bool local_bypass) | |||||
| : BaseRequest(RequestType::kCacheRow), | |||||
| support_local_bypass_(local_bypass), | |||||
| addr_(-1), | |||||
| sz_(0), | |||||
| row_id_from_server_(-1) { | |||||
| rq_.set_connection_id(connection_id); | |||||
| rq_.add_buf_data(cookie); | |||||
| } | |||||
| ~CacheRowRequest() = default; | ~CacheRowRequest() = default; | ||||
| /// \brief Serialize a TensorRow for streaming to the cache server | /// \brief Serialize a TensorRow for streaming to the cache server | ||||
| /// \param row TensorRow | /// \param row TensorRow | ||||
| /// \return Status object | /// \return Status object | ||||
| Status SerializeCacheRowRequest(const TensorRow &row); | |||||
| Status SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row); | |||||
| /// \brief Sanity check before we send the row. | |||||
| /// \return Status object | |||||
| Status Prepare() override; | |||||
| /// \brief Override the base function get the row id returned from the server | |||||
| /// \return Status object | |||||
| Status PostReply() override; | |||||
| /// \brief Return the row id assigned to this row for non-mappable dataset | /// \brief Return the row id assigned to this row for non-mappable dataset | ||||
| /// \return row id of the cached row | /// \return row id of the cached row | ||||
| row_id_type GetRowIdAfterCache() { return row_id_from_server_; } | row_id_type GetRowIdAfterCache() { return row_id_from_server_; } | ||||
| /// \brief If we are doing local bypass, fill in extra request information of where the data is located. | |||||
| void AddDataLocation() { | |||||
| if (support_local_bypass_) { | |||||
| rq_.add_buf_data(std::to_string(addr_)); | |||||
| rq_.add_buf_data(std::to_string(sz_)); | |||||
| } | |||||
| } | |||||
| /// \brief If we fail to send the data to the server using shared memory method, we should release | |||||
| /// the shared memory by sending another request. The following function will generate a suitable | |||||
| /// request for the CacheClient to send. | |||||
| std::shared_ptr<FreeSharedBlockRequest> GenerateFreeBlockRequest() { | |||||
| return std::make_shared<FreeSharedBlockRequest>(rq_.connection_id(), addr_); | |||||
| } | |||||
| private: | private: | ||||
| std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_; | |||||
| bool support_local_bypass_; | |||||
| int64_t addr_; | |||||
| int64_t sz_; | |||||
| row_id_type row_id_from_server_; | row_id_type row_id_from_server_; | ||||
| std::vector<const void *> buffers_; | |||||
| std::string cookie_; | |||||
| /// \brief Private function to serialize one TensorRow | |||||
| /// \param row TensorRow | |||||
| /// \return Status object | |||||
| Status SerializeTensorRowHeader(const TensorRow &row); | |||||
| /// \brief Private function to serialize one Tensor | |||||
| /// \param ts_ptr Tensor | |||||
| /// \return Status object | |||||
| Status SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off); | |||||
| }; | }; | ||||
| /// \brief Request to fetch rows in batch | /// \brief Request to fetch rows in batch | ||||
| class BatchFetchRequest : public BaseRequest { | class BatchFetchRequest : public BaseRequest { | ||||
| public: | public: | ||||
| friend class CacheServer; | friend class CacheServer; | ||||
| friend class CacheService; | friend class CacheService; | ||||
| BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id) | |||||
| : BaseRequest(connection_id, RequestType::kBatchFetchRows), row_id_(row_id) {} | |||||
| BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, bool local_bypass); | |||||
| ~BatchFetchRequest() = default; | ~BatchFetchRequest() = default; | ||||
| Status RestoreRows(TensorTable *out); | |||||
| Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr); | |||||
| private: | private: | ||||
| bool support_local_bypass_; | |||||
| std::vector<row_id_type> row_id_; | std::vector<row_id_type> row_id_; | ||||
| MemGuard<uint8_t> mem_; | |||||
| Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out); | |||||
| }; | }; | ||||
| /// \brief Request to create a cache for the current connection | /// \brief Request to create a cache for the current connection | ||||
| class CreationCacheRequest : public BaseRequest { | |||||
| class CreateCacheRequest : public BaseRequest { | |||||
| public: | public: | ||||
| friend class CacheServer; | friend class CacheServer; | ||||
| enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L }; | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| /// \param connection_id | /// \param connection_id | ||||
| /// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited | /// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited | ||||
| /// \param flag Attributes of the cache. | /// \param flag Attributes of the cache. | ||||
| explicit CreationCacheRequest(connection_id_type connection_id, uint64_t cache_mem_sz, | |||||
| CreateCacheFlag flag = CreateCacheFlag::kNone) | |||||
| : BaseRequest(connection_id, RequestType::kCreateCache), cache_mem_sz(cache_mem_sz), flag_(flag) {} | |||||
| ~CreationCacheRequest() = default; | |||||
| explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz, | |||||
| CreateCacheFlag flag = CreateCacheFlag::kNone); | |||||
| ~CreateCacheRequest() = default; | |||||
| void ParseResult(connection_id_type *id, std::string *out) { | |||||
| auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data()); | |||||
| *id = p->connection_id(); | |||||
| *out = p->cookie()->str(); | |||||
| } | |||||
| std::string cookie() const { return cookie_; } | |||||
| /// Overload the base class Prepare | |||||
| Status Prepare() override; | |||||
| private: | private: | ||||
| uint64_t cache_mem_sz; | |||||
| uint64_t cache_mem_sz_; | |||||
| CreateCacheFlag flag_; | CreateCacheFlag flag_; | ||||
| std::string cookie_; | |||||
| }; | }; | ||||
| /// \brief Request to purge a cache. | /// \brief Request to purge a cache. | ||||
| class PurgeCacheRequest : public BaseRequest { | class PurgeCacheRequest : public BaseRequest { | ||||
| public: | public: | ||||
| friend class CacheServer; | friend class CacheServer; | ||||
| explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kPurgeCache) {} | |||||
| explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kPurgeCache) { | |||||
| rq_.set_connection_id(connection_id); | |||||
| } | |||||
| ~PurgeCacheRequest() = default; | ~PurgeCacheRequest() = default; | ||||
| }; | }; | ||||
| /// \brief Request to destroy a cache | /// \brief Request to destroy a cache | ||||
| class DestroyCacheRequest : public BaseRequest { | class DestroyCacheRequest : public BaseRequest { | ||||
| public: | public: | ||||
| friend class CacheServer; | friend class CacheServer; | ||||
| explicit DestroyCacheRequest(connection_id_type connection_id) | |||||
| : BaseRequest(connection_id, RequestType::kDestroyCache) {} | |||||
| /// \brief Destructor | |||||
| explicit DestroyCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kDestroyCache) { | |||||
| rq_.set_connection_id(connection_id); | |||||
| } | |||||
| ~DestroyCacheRequest() = default; | ~DestroyCacheRequest() = default; | ||||
| }; | }; | ||||
| /// \brief Obtain the statistics of the current connection | /// \brief Obtain the statistics of the current connection | ||||
| class GetStatRequest : public BaseRequest { | class GetStatRequest : public BaseRequest { | ||||
| public: | public: | ||||
| friend class CacheServer; | friend class CacheServer; | ||||
| friend class CacheService; | friend class CacheService; | ||||
| explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kGetStat) {} | |||||
| explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(RequestType::kGetStat) { | |||||
| rq_.set_connection_id(connection_id); | |||||
| } | |||||
| ~GetStatRequest() = default; | ~GetStatRequest() = default; | ||||
| row_id_type GetMinRowId() const { | |||||
| auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer()); | |||||
| return msg->min_row_id(); | |||||
| } | |||||
| row_id_type GetMaxRowId() const { | |||||
| auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer()); | |||||
| return msg->max_row_id(); | |||||
| } | |||||
| int64_t GetNumMemCached() const { | |||||
| auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer()); | |||||
| return msg->num_mem_cached(); | |||||
| } | |||||
| int64_t GetNumDiskCached() const { | |||||
| auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer()); | |||||
| return msg->num_disk_cached(); | |||||
| } | |||||
| uint8_t GetState() const { | |||||
| auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer()); | |||||
| return msg->state(); | |||||
| /// \brief Override base function to process the result. | |||||
| Status PostReply() override; | |||||
| void GetStat(CacheServiceStat *stat) { | |||||
| if (stat != nullptr) { | |||||
| (*stat) = stat_; | |||||
| } | |||||
| } | } | ||||
| private: | private: | ||||
| MemGuard<uint8_t> mem_; | |||||
| CacheServiceStat stat_{}; | |||||
| }; | }; | ||||
| /// \brief Request to cache a schema | /// \brief Request to cache a schema | ||||
| class CacheSchemaRequest : public BaseRequest { | class CacheSchemaRequest : public BaseRequest { | ||||
| public: | public: | ||||
| friend class CacheServer; | friend class CacheServer; | ||||
| explicit CacheSchemaRequest(connection_id_type connection_id) | |||||
| : BaseRequest(connection_id, RequestType::kCacheSchema), buf_(nullptr), len_of_buf_(0) {} | |||||
| explicit CacheSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kCacheSchema) { | |||||
| rq_.set_connection_id(connection_id); | |||||
| } | |||||
| ~CacheSchemaRequest() = default; | ~CacheSchemaRequest() = default; | ||||
| Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map); | Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map); | ||||
| const void *GetBuffer() const { return buf_; } | |||||
| private: | |||||
| std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_; | |||||
| const void *buf_; | |||||
| int64_t len_of_buf_; | |||||
| }; | }; | ||||
| /// \brief Request to fetch a schema | /// \brief Request to fetch a schema | ||||
| class FetchSchemaRequest : public BaseRequest { | class FetchSchemaRequest : public BaseRequest { | ||||
| public: | public: | ||||
| friend class CacheServer; | friend class CacheServer; | ||||
| explicit FetchSchemaRequest(connection_id_type connection_id) | |||||
| : BaseRequest(connection_id, RequestType::kFetchSchema) {} | |||||
| explicit FetchSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kFetchSchema) { | |||||
| rq_.set_connection_id(connection_id); | |||||
| } | |||||
| ~FetchSchemaRequest() = default; | ~FetchSchemaRequest() = default; | ||||
| Status PostReply() override; | |||||
| std::unordered_map<std::string, int32_t> GetColumnMap(); | std::unordered_map<std::string, int32_t> GetColumnMap(); | ||||
| private: | private: | ||||
| MemGuard<uint8_t> mem_; | |||||
| std::unordered_map<std::string, int32_t> column_name_id_map_; | std::unordered_map<std::string, int32_t> column_name_id_map_; | ||||
| }; | }; | ||||
| /// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only. | /// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only. | ||||
| class BuildPhaseDoneRequest : public BaseRequest { | class BuildPhaseDoneRequest : public BaseRequest { | ||||
| public: | public: | ||||
| friend class CacheServer; | friend class CacheServer; | ||||
| BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie) | BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie) | ||||
| : BaseRequest(connection_id, RequestType::kBuildPhaseDone), cookie_(cookie) {} | |||||
| : BaseRequest(RequestType::kBuildPhaseDone), cookie_(cookie) { | |||||
| rq_.set_connection_id(connection_id); | |||||
| rq_.add_buf_data(cookie_); | |||||
| } | |||||
| ~BuildPhaseDoneRequest() = default; | ~BuildPhaseDoneRequest() = default; | ||||
| private: | private: | ||||
| std::string cookie_; | std::string cookie_; | ||||
| }; | }; | ||||
| /// \brief Request to drop all the caches in the current session | |||||
| class DropSessionRequest : public BaseRequest { | |||||
| public: | |||||
| friend class CacheServer; | |||||
| explicit DropSessionRequest(const CacheClientInfo &cinfo) : BaseRequest(RequestType::kDropSession) { | |||||
| rq_.mutable_connection_info()->operator=(cinfo); | |||||
| } | |||||
| ~DropSessionRequest() = default; | |||||
| }; | |||||
| class GenerateSessionIdRequest : public BaseRequest { | |||||
| public: | |||||
| friend class CacheServer; | |||||
| GenerateSessionIdRequest() : BaseRequest(RequestType::kGenerateSessionId) { | |||||
| // We don't have anything client info nor connection id to send. But we will manually | |||||
| // set the connection id to 0. | |||||
| rq_.set_connection_id(0); | |||||
| } | |||||
| ~GenerateSessionIdRequest() = default; | |||||
| session_id_type GetSessionId() { return atoi(reply_.result().data()); } | |||||
| }; | |||||
| class AllocateSharedBlockRequest : public BaseRequest { | |||||
| public: | |||||
| friend class CacheServer; | |||||
| explicit AllocateSharedBlockRequest(connection_id_type connection_id, size_t requestedSz) | |||||
| : BaseRequest(RequestType::kAllocateSharedBlock) { | |||||
| rq_.set_connection_id(connection_id); | |||||
| rq_.add_buf_data(std::to_string(requestedSz)); | |||||
| } | |||||
| ~AllocateSharedBlockRequest() = default; | |||||
| /// \brief On return from the server, we get the (relative) address where | |||||
| /// the free block is located. | |||||
| /// \return | |||||
| int64_t GetAddr() { | |||||
| auto addr = strtoll(reply_.result().data(), nullptr, 10); | |||||
| return addr; | |||||
| } | |||||
| }; | |||||
| class ShutdownRequest : public BaseRequest { | |||||
| public: | |||||
| friend class CacheServer; | |||||
| ShutdownRequest() : BaseRequest(RequestType::kStopService) {} | |||||
| ~ShutdownRequest() = default; | |||||
| }; | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_ | ||||
| @@ -14,25 +14,89 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "minddata/dataset/engine/cache/cache_server.h" | #include "minddata/dataset/engine/cache/cache_server.h" | ||||
| #include <algorithm> | |||||
| #include <functional> | |||||
| #include <limits> | |||||
| #include "minddata/dataset/core/constants.h" | |||||
| #include "minddata/dataset/engine/cache/cache_service.h" | #include "minddata/dataset/engine/cache/cache_service.h" | ||||
| #include "minddata/dataset/engine/cache/cache_request.h" | #include "minddata/dataset/engine/cache/cache_request.h" | ||||
| #include "minddata/dataset/util/bit.h" | #include "minddata/dataset/util/bit.h" | ||||
| #include "minddata/dataset/util/path.h" | |||||
| #include "minddata/dataset/util/random.h" | |||||
| #ifdef CACHE_LOCAL_CLIENT | |||||
| #include "minddata/dataset/util/sig_handler.h" | |||||
| #endif | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| CacheServer *CacheServer::instance_ = nullptr; | |||||
| std::once_flag CacheServer::init_instance_flag_; | |||||
| Status CacheServer::DoServiceStart() { | Status CacheServer::DoServiceStart() { | ||||
| #ifdef CACHE_LOCAL_CLIENT | |||||
| // We need to destroy the shared memory if user hits Control-C | |||||
| RegisterHandlers(); | |||||
| #endif | |||||
| if (!top_.empty()) { | if (!top_.empty()) { | ||||
| Path spill(top_); | Path spill(top_); | ||||
| RETURN_IF_NOT_OK(spill.CreateDirectories()); | RETURN_IF_NOT_OK(spill.CreateDirectories()); | ||||
| MS_LOG(INFO) << "CacheServer will use disk folder: " << top_; | MS_LOG(INFO) << "CacheServer will use disk folder: " << top_; | ||||
| } | } | ||||
| RETURN_IF_NOT_OK(vg_.ServiceStart()); | RETURN_IF_NOT_OK(vg_.ServiceStart()); | ||||
| cache_q_ = std::make_shared<Queue<BaseRequest *>>(1024); | |||||
| // There will be num_workers_ threads working on the grpc queue and | |||||
| // the same number of threads working on the CacheServerRequest queue. | |||||
| // Like a connector object we will set up the same number of queues but | |||||
| // we do not need to preserve any order. We will set the capacity of | |||||
| // each queue to be 128 since we are just pushing memory pointers which | |||||
| // is only 8 byte each. | |||||
| const int32_t que_capacity = 128; | |||||
| // This is the request queue from the client | |||||
| cache_q_ = std::make_shared<QueueList<CacheServerRequest *>>(); | |||||
| cache_q_->Init(num_workers_, que_capacity); | |||||
| // For the grpc completion queue to work, we need to allocate some | |||||
| // tags which in our case are instances of CacheServerQuest. | |||||
| // They got recycled and we will allocate them in advance and push | |||||
| // them into some free list. We need more (two or three times) the | |||||
| // size of the cache_q. While each worker is working on a CacheSerRequest, | |||||
| // we need some extra running injecting in the the qrpc completion queue. | |||||
| const int32_t multiplier = 3; | |||||
| const int32_t free_list_capacity = multiplier * (que_capacity + 1); | |||||
| free_list_ = std::make_shared<QueueList<CacheServerRequest *>>(); | |||||
| free_list_->Init(num_workers_, free_list_capacity); | |||||
| // We need to have a reference to the services memory pool in case | |||||
| // the Services goes out of scope earlier than us since it is a singleton | |||||
| mp_ = Services::GetInstance().GetServiceMemPool(); | |||||
| Allocator<CacheServerRequest> alloc(mp_); | |||||
| tag_.reserve(num_workers_); | |||||
| // Now we populate all free list. | |||||
| for (auto m = 0; m < num_workers_; ++m) { | |||||
| // Ideally we allocate all the free list in one malloc. But it turns out it exceeds the | |||||
| // Arena size. So we will we will allocate one segment at a time. | |||||
| auto my_tag = std::make_unique<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>(alloc); | |||||
| // Allocate the tag and assign it the current queue | |||||
| RETURN_IF_NOT_OK(my_tag->allocate(free_list_capacity, m)); | |||||
| for (int i = 0; i < free_list_capacity; ++i) { | |||||
| RETURN_IF_NOT_OK(free_list_->operator[](m)->Add((*my_tag)[i])); | |||||
| } | |||||
| tag_.push_back(std::move(my_tag)); | |||||
| } | |||||
| RETURN_IF_NOT_OK(cache_q_->Register(&vg_)); | RETURN_IF_NOT_OK(cache_q_->Register(&vg_)); | ||||
| auto f = std::bind(&CacheServer::ServerRequest, this); | |||||
| // Spawn a a few threads to serve the request. | |||||
| RETURN_IF_NOT_OK(free_list_->Register(&vg_)); | |||||
| // Spawn a few threads to serve the real request. | |||||
| auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1); | |||||
| for (auto i = 0; i < num_workers_; ++i) { | |||||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache service worker", std::bind(f, i))); | |||||
| } | |||||
| // Start the comm layer | |||||
| try { | |||||
| comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_, shared_memory_sz_in_gb_); | |||||
| RETURN_IF_NOT_OK(comm_layer_->Run()); | |||||
| } catch (const std::exception &e) { | |||||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||||
| } | |||||
| // Finally loop forever to handle the request. | |||||
| auto r = std::bind(&CacheServer::RpcRequest, this, std::placeholders::_1); | |||||
| for (auto i = 0; i < num_workers_; ++i) { | for (auto i = 0; i < num_workers_; ++i) { | ||||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache server", f)); | |||||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("rpc worker", std::bind(r, i))); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -65,188 +129,551 @@ CacheService *CacheServer::GetService(connection_id_type id) const { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| Status CacheServer::CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, | |||||
| BaseRequest::CreateCacheFlag flag, std::string *out_cookie) { | |||||
| Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing connection info"); | |||||
| std::string cookie; | |||||
| auto session_id = rq->connection_info().session_id(); | |||||
| auto crc = rq->connection_info().crc(); | |||||
| // We concat both numbers to form the internal connection id. | |||||
| auto connection_id = GetConnectionID(session_id, crc); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing info to create cache"); | |||||
| auto &create_cache_buf = rq->buf_data(0); | |||||
| auto p = flatbuffers::GetRoot<CreateCacheRequestMsg>(create_cache_buf.data()); | |||||
| auto flag = static_cast<CreateCacheRequest::CreateCacheFlag>(p->flag()); | |||||
| auto cache_mem_sz = p->cache_mem_sz(); | |||||
| // We can't do spilling unless this server is setup with a spill path in the first place | // We can't do spilling unless this server is setup with a spill path in the first place | ||||
| bool spill = (flag & BaseRequest::CreateCacheFlag::kSpillToDisk) == BaseRequest::CreateCacheFlag::kSpillToDisk; | |||||
| bool spill = | |||||
| (flag & CreateCacheRequest::CreateCacheFlag::kSpillToDisk) == CreateCacheRequest::CreateCacheFlag::kSpillToDisk; | |||||
| bool generate_id = | bool generate_id = | ||||
| (flag & BaseRequest::CreateCacheFlag::kGenerateRowId) == BaseRequest::CreateCacheFlag::kGenerateRowId; | |||||
| (flag & CreateCacheRequest::CreateCacheFlag::kGenerateRowId) == CreateCacheRequest::CreateCacheFlag::kGenerateRowId; | |||||
| if (spill && top_.empty()) { | if (spill && top_.empty()) { | ||||
| RETURN_STATUS_UNEXPECTED("Server is not set up with spill support."); | RETURN_STATUS_UNEXPECTED("Server is not set up with spill support."); | ||||
| } | } | ||||
| RETURN_UNEXPECTED_IF_NULL(out_cookie); | |||||
| *out_cookie = ""; | |||||
| flatbuffers::FlatBufferBuilder fbb; | |||||
| flatbuffers::Offset<flatbuffers::String> off_cookie; | |||||
| // Before creating the cache, first check if this is a request for a shared usage of an existing cache | // Before creating the cache, first check if this is a request for a shared usage of an existing cache | ||||
| // If two CreateService come in with identical connection_id, we need to serialize the create. | // If two CreateService come in with identical connection_id, we need to serialize the create. | ||||
| // The first create will be successful and be given a special cookie. | // The first create will be successful and be given a special cookie. | ||||
| UniqueLock lck(&rwLock_); | UniqueLock lck(&rwLock_); | ||||
| // Early exit if we are doing global shutdown | |||||
| if (global_shutdown_) { | |||||
| return Status::OK(); | |||||
| } | |||||
| auto end = all_caches_.end(); | auto end = all_caches_.end(); | ||||
| auto it = all_caches_.find(connection_id); | auto it = all_caches_.find(connection_id); | ||||
| bool duplicate = false; | |||||
| if (it == end) { | if (it == end) { | ||||
| std::unique_ptr<CacheService> cs; | std::unique_ptr<CacheService> cs; | ||||
| try { | try { | ||||
| cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id); | cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id); | ||||
| RETURN_IF_NOT_OK(cs->ServiceStart()); | RETURN_IF_NOT_OK(cs->ServiceStart()); | ||||
| *out_cookie = cs->cookie(); | |||||
| cookie = cs->cookie(); | |||||
| all_caches_.emplace(connection_id, std::move(cs)); | all_caches_.emplace(connection_id, std::move(cs)); | ||||
| } catch (const std::bad_alloc &e) { | } catch (const std::bad_alloc &e) { | ||||
| return Status(StatusCode::kOutOfMemory); | return Status(StatusCode::kOutOfMemory); | ||||
| } | } | ||||
| } else { | } else { | ||||
| duplicate = true; | |||||
| MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service"; | MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service"; | ||||
| // We can return OK but we will return a duplicate key so user can act accordingly to either ignore it | |||||
| // treat it as OK. | |||||
| return Status(StatusCode::kDuplicateKey); | |||||
| } | |||||
| off_cookie = fbb.CreateString(cookie); | |||||
| CreateCacheReplyMsgBuilder bld(fbb); | |||||
| bld.add_connection_id(connection_id); | |||||
| bld.add_cookie(off_cookie); | |||||
| auto off = bld.Finish(); | |||||
| fbb.Finish(off); | |||||
| reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); | |||||
| // Track the history of all the sessions that we have created so far. | |||||
| history_sessions_.insert(session_id); | |||||
| // We can return OK but we will return a duplicate key so user can act accordingly to either ignore it | |||||
| // treat it as OK. | |||||
| return duplicate ? Status(StatusCode::kDuplicateKey) : Status::OK(); | |||||
| } | |||||
| Status CacheServer::DestroyCache(CacheService *cs, CacheRequest *rq) { | |||||
| // We need a strong lock to protect the map. | |||||
| UniqueLock lck(&rwLock_); | |||||
| // it is already destroyed. Ignore it. | |||||
| if (cs != nullptr) { | |||||
| auto id = rq->connection_id(); | |||||
| MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id); | |||||
| // std::map will invoke the destructor of CacheService. So we don't need to do anything here. | |||||
| auto n = all_caches_.erase(id); | |||||
| if (n == 0) { | |||||
| // It has been destroyed by another duplicate request. | |||||
| MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service"; | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| inline Status CacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||||
| auto connection_id = rq->connection_id(); | |||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found"; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| auto sz = rq->buf_data_size(); | |||||
| std::vector<const void *> buffers; | |||||
| buffers.reserve(sz); | |||||
| // First piece of data is the cookie and is required | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie"); | |||||
| auto &cookie = rq->buf_data(0); | |||||
| // Only if the cookie matches, we can accept insert into this cache that has a build phase | |||||
| if (!cs->HasBuildPhase() || cookie == cs->cookie()) { | |||||
| // Push the address of each buffer (in the form of std::string coming in from protobuf) into | |||||
| // a vector of buffer | |||||
| for (auto i = 1; i < sz; ++i) { | |||||
| buffers.push_back(rq->buf_data(i).data()); | |||||
| } | |||||
| row_id_type id = -1; | |||||
| RETURN_IF_NOT_OK(cs->CacheRow(buffers, &id)); | |||||
| reply->set_result(std::to_string(id)); | |||||
| } else { | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); | |||||
| } | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| /// This is the main loop the cache server thread(s) are running. | |||||
| /// Each thread will pop a request and save the result in the same request. | |||||
| /// The sender will wait on the wait post in the request. Once the request | |||||
| /// is fulfilled, the server thread will do a post signalling the request is | |||||
| /// is processed. | |||||
| Status CacheServer::FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||||
| auto connection_id = rq->connection_id(); | |||||
| auto shared_pool = comm_layer_->GetSharedMemoryPool(); | |||||
| auto *base = shared_pool->SharedMemoryBaseAddr(); | |||||
| // Ensure we got 3 pieces of data coming in | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() == 3, "Incomplete data"); | |||||
| // First piece of data is the cookie and is required | |||||
| auto &cookie = rq->buf_data(0); | |||||
| // Second piece of data is the address where we can find the serialized data | |||||
| auto addr = strtoll(rq->buf_data(1).data(), nullptr, 10); | |||||
| auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr); | |||||
| // Third piece of data is the size of the serialized data that we need to transfer | |||||
| auto sz = strtoll(rq->buf_data(2).data(), nullptr, 10); | |||||
| // Successful or not, we need to free the memory on exit. | |||||
| Status rc; | |||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found"; | |||||
| rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| // Only if the cookie matches, we can accept insert into this cache that has a build phase | |||||
| if (!cs->HasBuildPhase() || cookie == cs->cookie()) { | |||||
| row_id_type id = -1; | |||||
| ReadableSlice src(p, sz); | |||||
| rc = cs->FastCacheRow(src, &id); | |||||
| reply->set_result(std::to_string(id)); | |||||
| } else { | |||||
| rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); | |||||
| } | |||||
| } | |||||
| // Return the block to the shared memory. | |||||
| shared_pool->Deallocate(p); | |||||
| return rc; | |||||
| } | |||||
| Status CacheServer::BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||||
| auto connection_id = rq->connection_id(); | |||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found"; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing row id"); | |||||
| auto &row_id_buf = rq->buf_data(0); | |||||
| auto p = flatbuffers::GetRoot<TensorRowIds>(row_id_buf.data()); | |||||
| std::vector<row_id_type> row_id; | |||||
| auto sz = p->row_id()->size(); | |||||
| row_id.reserve(sz); | |||||
| for (auto i = 0; i < sz; ++i) { | |||||
| row_id.push_back(p->row_id()->Get(i)); | |||||
| } | |||||
| int64_t mem_sz = 0; | |||||
| std::vector<key_size_pair> v; | |||||
| RETURN_IF_NOT_OK(cs->PreBatchFetch(row_id, &v, &mem_sz)); | |||||
| auto client_flag = rq->flag(); | |||||
| bool local_client = BitTest(client_flag, kLocalClientSupport); | |||||
| // For large amount data to be sent back, we will use shared memory provided it is a local | |||||
| // client that has local bypass support | |||||
| bool local_bypass = local_client ? (mem_sz >= kLocalByPassThreshold) : false; | |||||
| reply->set_flag(local_bypass ? kDataIsInSharedMemory : 0); | |||||
| if (local_bypass) { | |||||
| // We will use shared memory | |||||
| auto shared_pool = comm_layer_->GetSharedMemoryPool(); | |||||
| auto *base = shared_pool->SharedMemoryBaseAddr(); | |||||
| void *q = nullptr; | |||||
| RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q)); | |||||
| WritableSlice dest(q, mem_sz); | |||||
| RETURN_IF_NOT_OK(cs->BatchFetch(row_id, v, &dest)); | |||||
| // We can't return the absolute address which makes no sense to the client. | |||||
| // Instead we return the difference. | |||||
| auto difference = reinterpret_cast<int64_t>(q) - reinterpret_cast<int64_t>(base); | |||||
| reply->set_result(std::to_string(difference)); | |||||
| } else { | |||||
| // We are going to use std::string to allocate and hold the result which will be eventually | |||||
| // 'moved' to the protobuf message (which underneath is also a std::string) for the purpose | |||||
| // to minimize memory copy. | |||||
| std::string mem; | |||||
| try { | |||||
| mem.resize(mem_sz); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= mem_sz, "Programming error"); | |||||
| } catch (const std::bad_alloc &e) { | |||||
| return Status(StatusCode::kOutOfMemory); | |||||
| } | |||||
| WritableSlice dest(mem.data(), mem_sz); | |||||
| RETURN_IF_NOT_OK(cs->BatchFetch(row_id, v, &dest)); | |||||
| reply->set_result(std::move(mem)); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||||
| auto connection_id = rq->connection_id(); | |||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| CacheService::ServiceStat svc_stat; | |||||
| RETURN_IF_NOT_OK(cs->GetStat(&svc_stat)); | |||||
| flatbuffers::FlatBufferBuilder fbb; | |||||
| ServiceStatMsgBuilder bld(fbb); | |||||
| bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached); | |||||
| bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached); | |||||
| bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz); | |||||
| bld.add_max_row_id(svc_stat.max_); | |||||
| bld.add_min_row_id(svc_stat.min_); | |||||
| bld.add_state(svc_stat.state_); | |||||
| auto offset = bld.Finish(); | |||||
| fbb.Finish(offset); | |||||
| reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| inline Status CacheSchema(CacheService *cs, CacheRequest *rq) { | |||||
| auto connection_id = rq->connection_id(); | |||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing schema information"); | |||||
| auto &create_schema_buf = rq->buf_data(0); | |||||
| RETURN_IF_NOT_OK(cs->CacheSchema(create_schema_buf.data(), create_schema_buf.size())); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| inline Status FetchSchema(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||||
| auto connection_id = rq->connection_id(); | |||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| // We are going to use std::string to allocate and hold the result which will be eventually | |||||
| // 'moved' to the protobuf message (which underneath is also a std::string) for the purpose | |||||
| // to minimize memory copy. | |||||
| std::string mem; | |||||
| RETURN_IF_NOT_OK(cs->FetchSchema(&mem)); | |||||
| reply->set_result(std::move(mem)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| inline Status BuildPhaseDone(CacheService *cs, CacheRequest *rq) { | |||||
| auto connection_id = rq->connection_id(); | |||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| // First piece of data is the cookie | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie"); | |||||
| auto &cookie = rq->buf_data(0); | |||||
| // We can only allow to switch phase is the cookie match. | |||||
| if (cookie == cs->cookie()) { | |||||
| RETURN_IF_NOT_OK(cs->BuildPhaseDone()); | |||||
| } else { | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheServer::PurgeCache(CacheService *cs) { | |||||
| SharedLock lck(&rwLock_); | |||||
| // If shutdown in progress, ignore the command. | |||||
| if (global_shutdown_) { | |||||
| return Status::OK(); | |||||
| } | |||||
| // If cs is null, the cache has already been purged; ignore it. | |||||
| if (cs != nullptr) { | |||||
| RETURN_IF_NOT_OK(cs->Purge()); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| inline Status GenerateClientSessionID(session_id_type session_id, CacheReply *reply) { | |||||
| reply->set_result(std::to_string(session_id)); | |||||
| return Status::OK(); | |||||
| } | |||||
| /// \brief This is the main loop that the cache server thread(s) run. | |||||
| /// Each thread will pop a request and send the result back to the client using grpc. | |||||
| /// \return Status object | /// \return Status object | ||||
| Status CacheServer::ServerRequest() { | |||||
| Status CacheServer::ServerRequest(int32_t worker_id) { | |||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| // Loop forever until we are interrupted. | |||||
| while (true) { | |||||
| BaseRequest *base_rq = nullptr; | |||||
| RETURN_IF_NOT_OK(cache_q_->PopFront(&base_rq)); | |||||
| auto cs = GetService(base_rq->connection_id_); | |||||
| auto &my_que = cache_q_->operator[](worker_id); | |||||
| // Loop forever until we are interrupted or shut down. | |||||
| while (!global_shutdown_) { | |||||
| CacheServerRequest *cache_req = nullptr; | |||||
| RETURN_IF_NOT_OK(my_que->PopFront(&cache_req)); | |||||
| auto &rq = cache_req->rq_; | |||||
| auto &reply = cache_req->reply_; | |||||
| CacheService *cs = nullptr; | |||||
| // Requests come in roughly two sets. One set operates at the cache level with a connection id. | |||||
| // The other set operates at a higher level and without a connection id. | |||||
| if (!rq.has_connection_info()) { | |||||
| cs = GetService(rq.connection_id()); | |||||
| } | |||||
| // Except for creating a new session, we expect cs is not null. | // Except for creating a new session, we expect cs is not null. | ||||
| switch (base_rq->type_) { | |||||
| switch (cache_req->type_) { | |||||
| case BaseRequest::RequestType::kCacheRow: { | case BaseRequest::RequestType::kCacheRow: { | ||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found"; | |||||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| // Look into the flag to see where we can find the data and | |||||
| // call the appropriate method. | |||||
| auto flag = rq.flag(); | |||||
| if (BitTest(flag, kDataIsInSharedMemory)) { | |||||
| cache_req->rc_ = FastCacheRow(cs, &rq, &reply); | |||||
| } else { | } else { | ||||
| auto *rq = reinterpret_cast<CacheRowRequest *>(base_rq); | |||||
| // Only if the cookie matches, we can accept insert into this cache that has a build phase | |||||
| if (!cs->HasBuildPhase() || rq->cookie_ == cs->cookie()) { | |||||
| rq->rc_ = cs->CacheRow(rq->buffers_, &rq->row_id_from_server_); | |||||
| } else { | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); | |||||
| } | |||||
| cache_req->rc_ = CacheRow(cs, &rq, &reply); | |||||
| } | } | ||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kBatchFetchRows: { | case BaseRequest::RequestType::kBatchFetchRows: { | ||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found"; | |||||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| auto *rq = reinterpret_cast<BatchFetchRequest *>(base_rq); | |||||
| rq->rc_ = cs->BatchFetch(rq->row_id_, &rq->mem_); | |||||
| } | |||||
| cache_req->rc_ = BatchFetchRows(cs, &rq, &reply); | |||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kCreateCache: { | case BaseRequest::RequestType::kCreateCache: { | ||||
| // If the cache has already been created, we still run the creation path so that we can | |||||
| // sanity-check the client id and return the cache id to the user. | |||||
| auto *rq = reinterpret_cast<CreationCacheRequest *>(base_rq); | |||||
| rq->rc_ = CreateService(rq->connection_id_, rq->cache_mem_sz, rq->flag_, &rq->cookie_); | |||||
| cache_req->rc_ = CreateService(&rq, &reply); | |||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kPurgeCache: { | case BaseRequest::RequestType::kPurgeCache: { | ||||
| if (cs != nullptr) { | |||||
| base_rq->rc_ = cs->Purge(); | |||||
| } else { | |||||
| // it is already purged. Ignore it. | |||||
| base_rq->rc_ = Status::OK(); | |||||
| } | |||||
| cache_req->rc_ = PurgeCache(cs); | |||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kDestroyCache: { | case BaseRequest::RequestType::kDestroyCache: { | ||||
| if (cs != nullptr) { | |||||
| // We need a strong lock to protect the map. | |||||
| connection_id_type id = base_rq->connection_id_; | |||||
| UniqueLock lck(&rwLock_); | |||||
| // std::map will invoke the constructor of CacheService. So we don't need to do anything here. | |||||
| auto n = all_caches_.erase(id); | |||||
| if (n == 0) { | |||||
| // It has been destroyed by another duplicate request. | |||||
| MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service"; | |||||
| } | |||||
| base_rq->rc_ = Status::OK(); | |||||
| } else { | |||||
| // it is already destroyed. Ignore it. | |||||
| base_rq->rc_ = Status::OK(); | |||||
| } | |||||
| cache_req->rc_ = DestroyCache(cs, &rq); | |||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kGetStat: { | case BaseRequest::RequestType::kGetStat: { | ||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; | |||||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| auto *rq = reinterpret_cast<GetStatRequest *>(base_rq); | |||||
| CacheService::ServiceStat svc_stat; | |||||
| rq->rc_ = cs->GetStat(&svc_stat); | |||||
| if (rq->rc_.IsOk()) { | |||||
| flatbuffers::FlatBufferBuilder fbb; | |||||
| ServiceStatMsgBuilder bld(fbb); | |||||
| bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached); | |||||
| bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached); | |||||
| bld.add_max_row_id(svc_stat.max_); | |||||
| bld.add_min_row_id(svc_stat.min_); | |||||
| bld.add_state(svc_stat.state_); | |||||
| auto offset = bld.Finish(); | |||||
| fbb.Finish(offset); | |||||
| rq->rc_ = rq->mem_.allocate(fbb.GetSize()); | |||||
| if (rq->rc_.IsOk()) { | |||||
| WritableSlice dest(rq->mem_.GetMutablePointer(), fbb.GetSize()); | |||||
| ReadableSlice src(fbb.GetBufferPointer(), fbb.GetSize()); | |||||
| RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src)); | |||||
| } | |||||
| } | |||||
| } | |||||
| cache_req->rc_ = GetStat(cs, &rq, &reply); | |||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kCacheSchema: { | case BaseRequest::RequestType::kCacheSchema: { | ||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; | |||||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| auto *rq = reinterpret_cast<CacheSchemaRequest *>(base_rq); | |||||
| rq->rc_ = cs->CacheSchema(rq->buf_, rq->len_of_buf_); | |||||
| } | |||||
| cache_req->rc_ = CacheSchema(cs, &rq); | |||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kFetchSchema: { | case BaseRequest::RequestType::kFetchSchema: { | ||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; | |||||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| auto *rq = reinterpret_cast<FetchSchemaRequest *>(base_rq); | |||||
| rq->rc_ = cs->FetchSchema(&rq->mem_); | |||||
| } | |||||
| cache_req->rc_ = FetchSchema(cs, &rq, &reply); | |||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kBuildPhaseDone: { | case BaseRequest::RequestType::kBuildPhaseDone: { | ||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; | |||||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| auto *rq = reinterpret_cast<BuildPhaseDoneRequest *>(base_rq); | |||||
| // We can only allow to switch phase is the cookie match. | |||||
| if (rq->cookie_ == cs->cookie()) { | |||||
| rq->rc_ = cs->BuildPhaseDone(); | |||||
| } else { | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); | |||||
| } | |||||
| } | |||||
| cache_req->rc_ = BuildPhaseDone(cs, &rq); | |||||
| break; | |||||
| } | |||||
| case BaseRequest::RequestType::kDropSession: { | |||||
| cache_req->rc_ = DestroySession(&rq); | |||||
| break; | |||||
| } | |||||
| case BaseRequest::RequestType::kGenerateSessionId: { | |||||
| cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply); | |||||
| break; | |||||
| } | |||||
| case BaseRequest::RequestType::kAllocateSharedBlock: { | |||||
| cache_req->rc_ = AllocateSharedMemory(&rq, &reply); | |||||
| break; | |||||
| } | |||||
| case BaseRequest::RequestType::kFreeSharedBlock: { | |||||
| cache_req->rc_ = FreeSharedMemory(&rq); | |||||
| break; | |||||
| } | |||||
| case BaseRequest::RequestType::kStopService: { | |||||
| // This command shuts down everything. | |||||
| cache_req->rc_ = GlobalShutdown(); | |||||
| break; | break; | ||||
| } | } | ||||
| default: | default: | ||||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Unknown request type"); | |||||
| std::string errMsg("Unknown request type : "); | |||||
| errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_)); | |||||
| cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } | } | ||||
| // Notify it is done, and move on to the next request. | // Notify it is done, and move on to the next request. | ||||
| base_rq->wp_.Set(); | |||||
| Status2CacheReply(cache_req->rc_, &reply); | |||||
| cache_req->st_ = CacheServerRequest::STATE::FINISH; | |||||
| // We will re-tag the request back to the grpc queue. Once it comes back from the client, | |||||
| // the CacheServerRequest, i.e. the pointer cache_req, will be freed. | |||||
| if (!global_shutdown_) { | |||||
| cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
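Note how the worker re-tags a finished request onto the grpc completion queue via `responder_.Finish(...)`; some grpc thread must then drain that queue and recycle the tag. A hedged sketch of the completion-queue loop that `RpcRequest`/`HandleRequest` presumably drive (the queue wiring and the dispatch branch are assumptions, not shown in this patch):

```cpp
#include <grpcpp/grpcpp.h>

// Hedged sketch of an async-gRPC tag loop. `cq` would be owned by the
// comm layer (CacheServerGreeterImpl); the dispatch branch is elided.
void TagLoop(grpc::ServerCompletionQueue *cq) {
  void *tag = nullptr;
  bool ok = false;
  while (cq->Next(&tag, &ok)) {  // blocks until a completion event arrives
    auto *rq = static_cast<CacheServerRequest *>(tag);
    if (rq->st_ == CacheServerRequest::STATE::FINISH) {
      // The reply has gone out; recycle the tag back to the free list.
      (void)CacheServer::ReturnRequestTag(rq);
    } else {
      // A new incoming request: hand it to a server worker, e.g. via
      // CacheServer::GetInstance().PushRequest(queue_id, rq).
    }
  }
}
```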
| connection_id_type CacheServer::GetConnectionID(session_id_type session_id, uint32_t crc) const { | |||||
| connection_id_type connection_id = | |||||
| (static_cast<connection_id_type>(session_id) << 32u) | static_cast<connection_id_type>(crc); | |||||
| return connection_id; | |||||
| } | |||||
| session_id_type CacheServer::GetSessionID(connection_id_type connection_id) const { | |||||
| return static_cast<session_id_type>(connection_id >> 32u); | |||||
| } | |||||
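The connection id is plain bit-packing: the 32-bit session id occupies the upper half of the 64-bit word and the crc the lower half, so the session id is always recoverable with a shift. A tiny worked example (values arbitrary):

```cpp
#include <cstdint>

// Pack session_id 0x0000A1B2 and crc 0x00C3D4E5 into one connection id.
uint64_t connection_id = (uint64_t{0x0000A1B2} << 32u) | uint64_t{0x00C3D4E5};
// connection_id == 0x0000A1B200C3D4E5
uint32_t session_id = static_cast<uint32_t>(connection_id >> 32u);  // 0x0000A1B2 again
```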
| CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, | |||||
| int32_t shared_memory_sz_in_gb) | |||||
| : top_(spill_path), | |||||
| num_workers_(num_workers), | |||||
| port_(port), | |||||
| shared_memory_sz_in_gb_(shared_memory_sz_in_gb), | |||||
| global_shutdown_(false) {} | |||||
| Status CacheServer::Run() { | |||||
| RETURN_IF_NOT_OK(ServiceStart()); | |||||
| // This is called by the main function and we shouldn't exit. Otherwise the main thread | |||||
| // will just shut down. So we will call some function that never returns unless there is an error. | |||||
| // One good choice is simply to wait for all threads to return. | |||||
| RETURN_IF_NOT_OK(vg_.join_all(Task::WaitFlag::kBlocking)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheServer::GetFreeRequestTag(int32_t queue_id, CacheServerRequest **q) { | |||||
| RETURN_UNEXPECTED_IF_NULL(q); | |||||
| CacheServer &cs = CacheServer::GetInstance(); | |||||
| CacheServerRequest *p; | |||||
| RETURN_IF_NOT_OK(cs.free_list_->operator[](queue_id)->PopFront(&p)); | |||||
| *q = p; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheServer::ReturnRequestTag(CacheServerRequest *p) { | |||||
| RETURN_UNEXPECTED_IF_NULL(p); | |||||
| int32_t myQID = p->getQid(); | |||||
| // Free any memory from the protobufs | |||||
| p->~CacheServerRequest(); | |||||
| // Re-initialize the memory | |||||
| new (p) CacheServerRequest(myQID); | |||||
| // Now we return it back to free list. | |||||
| CacheServer &cs = CacheServer::GetInstance(); | |||||
| RETURN_IF_NOT_OK(cs.free_list_->operator[](myQID)->Add(p)); | |||||
| return Status::OK(); | |||||
| } | |||||
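ReturnRequestTag recycles a request object in place: run the destructor so the embedded protobufs release their heap memory, then placement-new a fresh object into the same storage before putting it back on the free list. A self-contained sketch of the idiom (`Widget` stands in for `CacheServerRequest`):

```cpp
#include <new>
#include <string>

struct Widget {
  explicit Widget(int qid) : qid_(qid) {}
  int qid_;
  std::string payload_;  // owns heap memory, like the protobufs in the real tag
};

void Recycle(Widget *p) {
  int qid = p->qid_;    // remember what is needed to reconstruct
  p->~Widget();         // release owned resources, keep the storage
  new (p) Widget(qid);  // re-initialize the same memory for reuse
}
```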
| Status CacheServer::DestroySession(CacheRequest *rq) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id"); | |||||
| auto drop_session_id = rq->connection_info().session_id(); | |||||
| UniqueLock lck(&rwLock_); | |||||
| // We could just call DestroyCache() but we are already holding the lock; doing so would deadlock. | |||||
| // So we do it manually. Erase via the iterator: erasing by key inside a range-for would | |||||
| // invalidate the iterator we are walking with. | |||||
| for (auto it = all_caches_.begin(); it != all_caches_.end();) { | |||||
| auto connection_id = it->first; | |||||
| auto session_id = GetSessionID(connection_id); | |||||
| if (session_id == drop_session_id) { | |||||
| // std::map will invoke the destructor of CacheService. So we don't need to do anything here. | |||||
| it = all_caches_.erase(it); | |||||
| MS_LOG(INFO) << "Destroy cache with id " << connection_id; | |||||
| } else { | |||||
| ++it; | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| session_id_type CacheServer::GenerateSessionID() const { | |||||
| SharedLock lock(&rwLock_); | |||||
| auto mt = GetRandomDevice(); | |||||
| std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<session_id_type>::max()); | |||||
| session_id_type session_id; | |||||
| bool duplicate = false; | |||||
| do { | |||||
| session_id = distribution(mt); | |||||
| auto it = history_sessions_.find(session_id); | |||||
| duplicate = (it != history_sessions_.end()); | |||||
| } while (duplicate); | |||||
| return session_id; | |||||
| } | |||||
| Status CacheServer::AllocateSharedMemory(CacheRequest *rq, CacheReply *reply) { | |||||
| auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, 10); | |||||
| auto shared_pool = comm_layer_->GetSharedMemoryPool(); | |||||
| auto *base = shared_pool->SharedMemoryBaseAddr(); | |||||
| void *p = nullptr; | |||||
| RETURN_IF_NOT_OK(shared_pool->Allocate(requestedSz, &p)); | |||||
| // We can't return the absolute address, which would make no sense to the client. | |||||
| // Instead we return the offset (difference) from the shared memory base. | |||||
| auto difference = reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base); | |||||
| reply->set_result(std::to_string(difference)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheServer::FreeSharedMemory(CacheRequest *rq) { | |||||
| auto shared_pool = comm_layer_->GetSharedMemoryPool(); | |||||
| auto *base = shared_pool->SharedMemoryBaseAddr(); | |||||
| auto addr = strtoll(rq->buf_data(0).data(), nullptr, 10); | |||||
| auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr); | |||||
| shared_pool->Deallocate(p); | |||||
| return Status::OK(); | |||||
| } | |||||
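Because the client and the server map the same shared memory segment at different virtual addresses, only the offset from the segment base survives the process boundary; that is why the reply carries `p - base` as text. A hedged sketch of the client-side inverse (the client's base pointer would come from its own mapping of the segment, which is not shown in this patch):

```cpp
#include <cstdint>
#include <cstdlib>
#include <string>

// Turn the offset in the reply back into a pointer valid in *this* process.
void *OffsetToLocalPointer(void *local_base, const std::string &reply_result) {
  int64_t offset = strtoll(reply_result.data(), nullptr, 10);
  return reinterpret_cast<void *>(reinterpret_cast<int64_t>(local_base) + offset);
}
```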
| Status CacheServer::RpcRequest(int32_t worker_id) { | |||||
| TaskManager::FindMe()->Post(); | |||||
| RETURN_IF_NOT_OK(comm_layer_->HandleRequest(worker_id)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheServer::GlobalShutdown() { | |||||
| // Let's shutdown in proper order. | |||||
| bool expected = false; | |||||
| if (global_shutdown_.compare_exchange_strong(expected, true)) { | |||||
| MS_LOG(WARNING) << "Shutting down server."; | |||||
| // Shut down the grpc queue. No new requests will be accepted. | |||||
| // The threads we spawn to work on the grpc queue will exit by themselves once | |||||
| // they notice the queue has been shut down. | |||||
| comm_layer_->Shutdown(); | |||||
| // Now we interrupt any threads that are waiting on cache_q_ | |||||
| vg_.interrupt_all(); | |||||
| // The next thing to do is drop all the caches. | |||||
| UniqueLock lck(&rwLock_); | |||||
| // Erase via the iterator; erasing by key inside a range-for would invalidate the iterator. | |||||
| for (auto it = all_caches_.begin(); it != all_caches_.end();) { | |||||
| auto id = it->first; | |||||
| MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id); | |||||
| // Wait for all outstanding work to be finished. | |||||
| auto &cs = it->second; | |||||
| UniqueLock cs_lock(&cs->rw_lock_); | |||||
| // std::map will invoke the destructor of CacheService. So we don't need to do anything here. | |||||
| it = all_caches_.erase(it); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheServer::Builder::SanityCheck() { | |||||
| if (shared_memory_sz_in_gb_ <= 0) { | |||||
| RETURN_STATUS_UNEXPECTED("Shared memory size (in GB unit) must be positive"); | |||||
| } | |||||
| if (num_workers_ <= 0) { | |||||
| RETURN_STATUS_UNEXPECTED("Number of parallel workers must be positive"); | |||||
| } | |||||
| if (!top_.empty()) { | |||||
| auto p = top_.data(); | |||||
| if (p[0] != '/') { | |||||
| RETURN_STATUS_UNEXPECTED("Spilling directory must be an absolute path"); | |||||
| } | |||||
| // Check if the spill directory is writable | |||||
| Path spill(top_); | |||||
| auto t = spill / Services::GetUniqueID(); | |||||
| Status rc = t.CreateDirectory(); | |||||
| if (rc.IsOk()) { | |||||
| rc = t.Remove(); | |||||
| } | |||||
| if (rc.IsError()) { | |||||
| RETURN_STATUS_UNEXPECTED("Spilling directory is not writable\n" + rc.ToString()); | |||||
| } | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers) | |||||
| : top_(spill_path), num_workers_(num_workers) {} | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -24,8 +24,11 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | #include <map> | ||||
| #include <set> | |||||
| #include "minddata/dataset/engine/cache/cache_service.h" | #include "minddata/dataset/engine/cache/cache_service.h" | ||||
| #include "minddata/dataset/engine/cache/cache_grpc_server.h" | |||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/util/allocator.h" | |||||
| #include "minddata/dataset/util/arena.h" | #include "minddata/dataset/util/arena.h" | ||||
| #include "minddata/dataset/util/cache_pool.h" | #include "minddata/dataset/util/cache_pool.h" | ||||
| #include "minddata/dataset/util/lock.h" | #include "minddata/dataset/util/lock.h" | ||||
| @@ -37,43 +40,131 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class BaseRequest; | |||||
| /// \brief A server which provides CacheService services. | /// \brief A server which provides CacheService services. | ||||
| class CacheServer : public Service { | class CacheServer : public Service { | ||||
| public: | public: | ||||
| friend class Services; | friend class Services; | ||||
| using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>; | using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>; | ||||
| class Builder { | |||||
| public: | |||||
| Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4) {} | |||||
| /// \brief Getter functions | |||||
| const std::string &getTop() const { return top_; } | |||||
| int32_t getNumWorkers() const { return num_workers_; } | |||||
| int32_t getPort() const { return port_; } | |||||
| int32_t getSharedMemorySzInGb() const { return shared_memory_sz_in_gb_; } | |||||
| Builder &SetRootDirectory(std::string root) { | |||||
| top_ = std::move(root); | |||||
| return *this; | |||||
| } | |||||
| Builder &SetNumWorkers(int32_t n) { | |||||
| num_workers_ = n; | |||||
| return *this; | |||||
| } | |||||
| Builder &SetPort(int32_t p) { | |||||
| port_ = p; | |||||
| return *this; | |||||
| } | |||||
| Builder &SetSharedMemorySizeInGB(int32_t sz) { | |||||
| shared_memory_sz_in_gb_ = sz; | |||||
| return *this; | |||||
| } | |||||
| Status SanityCheck(); | |||||
| void Print(std::ostream &out) const { | |||||
| out << "Summary of the cache server configuration\n" | |||||
| << "Spill directory: " << getTop() << "\n" | |||||
| << "Number of parallel workers: " << getNumWorkers() << "\n" | |||||
| << "Tcp/ip port: " << getPort() << "\n" | |||||
| << "Shared memory size (in GB): " << getSharedMemorySzInGb(); | |||||
| } | |||||
| friend std::ostream &operator<<(std::ostream &out, const Builder &bld) { | |||||
| bld.Print(out); | |||||
| return out; | |||||
| } | |||||
| Status Build() { | |||||
| RETURN_IF_NOT_OK(SanityCheck()); | |||||
| // We need to bring up the Task Manager by bringing up the Services singleton. | |||||
| RETURN_IF_NOT_OK(Services::CreateInstance()); | |||||
| RETURN_IF_NOT_OK(CacheServer::CreateInstance(top_, num_workers_, port_, shared_memory_sz_in_gb_)); | |||||
| return Status::OK(); | |||||
| } | |||||
| private: | |||||
| std::string top_; | |||||
| int32_t num_workers_; | |||||
| int32_t port_; | |||||
| int32_t shared_memory_sz_in_gb_; | |||||
| }; | |||||
| CacheServer(const CacheServer &) = delete; | CacheServer(const CacheServer &) = delete; | ||||
| CacheServer &operator=(const CacheServer &) = delete; | CacheServer &operator=(const CacheServer &) = delete; | ||||
| CacheServer(CacheServer &&) = delete; | CacheServer(CacheServer &&) = delete; | ||||
| CacheServer &operator=(CacheServer &&) = delete; | CacheServer &operator=(CacheServer &&) = delete; | ||||
| static CacheServer &GetInstance() noexcept { return Services::getCacheServer(); } | |||||
| Status DoServiceStart() override; | Status DoServiceStart() override; | ||||
| Status DoServiceStop() override; | Status DoServiceStop() override; | ||||
| ~CacheServer() { (void)ServiceStop(); } | ~CacheServer() { (void)ServiceStop(); } | ||||
| static Status CreateInstance(const std::string &spill_path, int32_t num_workers, int32_t port, | |||||
| int32_t shared_memory_sz) { | |||||
| // Note: std::call_once discards the callable's return value, so capture the Status | |||||
| // instead of returning it from the lambda. | |||||
| Status rc = Status::OK(); | |||||
| std::call_once(init_instance_flag_, [&]() { | |||||
| auto &svcManager = Services::GetInstance(); | |||||
| rc = svcManager.AddHook(&instance_, spill_path, num_workers, port, shared_memory_sz); | |||||
| }); | |||||
| return rc; | |||||
| } | |||||
| static CacheServer &GetInstance() { return *instance_; } | |||||
| /// \brief For the current demonstration, a cache client contacts cache server using a Queue. | /// \brief For the current demonstration, a cache client contacts cache server using a Queue. | ||||
| /// \param rq | /// \param rq | ||||
| /// \return Status object | /// \return Status object | ||||
| Status PushRequest(BaseRequest *rq) { | |||||
| Status PushRequest(int32_t queue_id, CacheServerRequest *rq) { | |||||
| RETURN_UNEXPECTED_IF_NULL(rq); | RETURN_UNEXPECTED_IF_NULL(rq); | ||||
| RETURN_IF_NOT_OK(cache_q_->Add(rq)); | |||||
| RETURN_IF_NOT_OK(cache_q_->operator[](queue_id)->Add(rq)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| /// \brief Kick off server threads. Never returns unless an error occurs. | |||||
| Status Run(); | |||||
| /// \brief Get a free tag | |||||
| /// \param[out] q Pointer to a pointer to a CacheServerRequest | |||||
| /// \return Status object | |||||
| static Status GetFreeRequestTag(int32_t queue_id, CacheServerRequest **q); | |||||
| /// \brief Return a tag to the free list | |||||
| /// \param[in] p Pointer to an already finished CacheServerRequest tag | |||||
| /// \return Status object | |||||
| static Status ReturnRequestTag(CacheServerRequest *p); | |||||
| private: | private: | ||||
| static std::once_flag init_instance_flag_; | |||||
| static CacheServer *instance_; | |||||
| mutable RWLock rwLock_; | mutable RWLock rwLock_; | ||||
| std::string top_; | std::string top_; | ||||
| cache_index all_caches_; | cache_index all_caches_; | ||||
| std::shared_ptr<Queue<BaseRequest *>> cache_q_; | |||||
| std::set<session_id_type> history_sessions_; | |||||
| std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_; | |||||
| std::shared_ptr<QueueList<CacheServerRequest *>> free_list_; | |||||
| std::vector<std::unique_ptr<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>> tag_; | |||||
| std::shared_ptr<CacheServerGreeterImpl> comm_layer_; | |||||
| std::shared_ptr<MemoryPool> mp_; | |||||
| TaskGroup vg_; | TaskGroup vg_; | ||||
| int32_t num_workers_; | int32_t num_workers_; | ||||
| int32_t port_; | |||||
| int32_t shared_memory_sz_in_gb_; | |||||
| std::atomic<bool> global_shutdown_; | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| /// \param spill_path Top directory for spilling buffers to. | /// \param spill_path Top directory for spilling buffers to. | ||||
| /// \param num_workers Number of threads for handling requests. | /// \param num_workers Number of threads for handling requests. | ||||
| explicit CacheServer(const std::string &spill_path, int32_t num_workers = 3); | |||||
| explicit CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, int32_t shared_memory_sz_in_gb); | |||||
| /// \brief Locate a cache service from connection id. | /// \brief Locate a cache service from connection id. | ||||
| /// \return Pointer to cache service. Null if not found | /// \return Pointer to cache service. Null if not found | ||||
| @@ -82,16 +173,65 @@ class CacheServer : public Service { | |||||
| /// \brief Create a cache service. We allow multiple clients to create the same cache service. | /// \brief Create a cache service. We allow multiple clients to create the same cache service. | ||||
| /// Subsequent duplicate requests are ignored. The first cache client to create the service will be given | /// Subsequent duplicate requests are ignored. The first cache client to create the service will be given | ||||
| /// a special unique cookie. | /// a special unique cookie. | ||||
| /// \param[in] connection_id This is from a Cache client. | |||||
| /// \param[in] cache_mem_sz | |||||
| /// \param[in] flag | |||||
| /// \param[out] out_cookie Only the first cache client will be given a special cookie to identify the creator | |||||
| /// \return Status object | /// \return Status object | ||||
| Status CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, BaseRequest::CreateCacheFlag flag, | |||||
| std::string *out_cookie); | |||||
| Status CreateService(CacheRequest *rq, CacheReply *reply); | |||||
| /// \brief Destroy a cache service | |||||
| /// \param cs Pointer to the cache service. May be null if it has already been destroyed. | |||||
| /// \param rq The incoming request | |||||
| /// \return Status object | |||||
| Status DestroyCache(CacheService *cs, CacheRequest *rq); | |||||
| Status PurgeCache(CacheService *cs); | |||||
| /// \brief Entry point for all internal server threads. | |||||
| Status ServerRequest(int32_t worker_id); | |||||
| /// \brief Entry point for all grpc threads. | |||||
| /// \return Status object | |||||
| Status RpcRequest(int32_t worker_id); | |||||
| Status DestroySession(CacheRequest *rq); | |||||
| /// \brief Create a connection id from a session id and a crc | |||||
| /// \param session_id | |||||
| /// \param crc | |||||
| /// \return connection id | |||||
| connection_id_type GetConnectionID(session_id_type session_id, uint32_t crc) const; | |||||
| /// \brief Extract the session id from a connection id | |||||
| /// \param connection_id | |||||
| /// \return session id | |||||
| session_id_type GetSessionID(connection_id_type connection_id) const; | |||||
| /// \brief Generate a session ID for the client | |||||
| /// \return Session ID | |||||
| session_id_type GenerateSessionID() const; | |||||
| /// \brief Handle kAllocateSharedBlock request | |||||
| /// \param rq CacheRequest | |||||
| /// \param reply CacheReply | |||||
| /// \return Status object | |||||
| Status AllocateSharedMemory(CacheRequest *rq, CacheReply *reply); | |||||
| /// \brief Handle kFreeSharedBlock request | |||||
| /// \param rq | |||||
| /// \return Status object | |||||
| Status FreeSharedMemory(CacheRequest *rq); | |||||
| /// \brief Entry point for all server threads. | |||||
| Status ServerRequest(); | |||||
| /// \brief Handle kFastCacheRow request | |||||
| /// \return Status object | |||||
| Status FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply); | |||||
| /// \brief Internal function to do row batch fetch | |||||
| /// \param cs CacheService | |||||
| /// \param rq Request | |||||
| /// \param reply Reply | |||||
| /// \return Status object | |||||
| Status BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply); | |||||
| /// \brief A proper shutdown of the server | |||||
| /// \return Status object | |||||
| Status GlobalShutdown(); | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
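For orientation, a minimal sketch of how a standalone cache-server binary might drive the new Builder API; the entry point and the hard-coded values are assumptions, not part of this patch:

```cpp
#include <iostream>

int main() {
  mindspore::dataset::CacheServer::Builder builder;
  builder.SetRootDirectory("/tmp").SetNumWorkers(32).SetPort(50052).SetSharedMemorySizeInGB(4);
  std::cout << builder << std::endl;  // print the configuration summary
  auto rc = builder.Build();          // sanity check, then create the singleton
  if (rc.IsOk()) {
    // Run() spawns the worker and grpc threads and blocks until shutdown or error.
    rc = mindspore::dataset::CacheServer::GetInstance().Run();
  }
  return rc.IsOk() ? 0 : 1;
}
```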
| @@ -76,7 +76,7 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type | |||||
| *row_id_generated = GetNextRowId(); | *row_id_generated = GetNextRowId(); | ||||
| // Some debug information on how many rows we have generated so far. | // Some debug information on how many rows we have generated so far. | ||||
| if ((*row_id_generated) % 1000 == 0) { | if ((*row_id_generated) % 1000 == 0) { | ||||
| MS_LOG(DEBUG) << "Number of rows cached: " << *row_id_generated; | |||||
| MS_LOG(DEBUG) << "Number of rows cached: " << (*row_id_generated) + 1; | |||||
| } | } | ||||
| } else { | } else { | ||||
| if (msg->row_id() < 0) { | if (msg->row_id() < 0) { | ||||
| @@ -114,6 +114,45 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type | |||||
| RETURN_STATUS_UNEXPECTED(e.what()); | RETURN_STATUS_UNEXPECTED(e.what()); | ||||
| } | } | ||||
| } | } | ||||
| Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) { | |||||
| SharedLock rw(&rw_lock_); | |||||
| RETURN_UNEXPECTED_IF_NULL(row_id_generated); | |||||
| if (st_ == State::kFetchPhase) { | |||||
| // For this kind of cache service, once we move from the build phase into the fetch phase, | |||||
| // we can't allow others to cache more rows. | |||||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||||
| } | |||||
| try { | |||||
| // If we are asked to generate the row id, do so; otherwise it must be found in the buffer. | |||||
| if (generate_id_) { | |||||
| *row_id_generated = GetNextRowId(); | |||||
| // Some debug information on how many rows we have generated so far. | |||||
| if ((*row_id_generated) % 1000 == 0) { | |||||
| MS_LOG(DEBUG) << "Number of rows cached: " << (*row_id_generated) + 1; | |||||
| } | |||||
| } else { | |||||
| auto msg = GetTensorRowHeaderMsg(src.GetPointer()); | |||||
| if (msg->row_id() < 0) { | |||||
| std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id()); | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| *row_id_generated = msg->row_id(); | |||||
| } | |||||
| // Now we cache the flat buffer. | |||||
| CachePool::key_type key; | |||||
| RETURN_IF_NOT_OK(cp_->Insert({src}, &key)); | |||||
| Status rc = map_->DoInsert(*row_id_generated, key); | |||||
| if (rc == Status(StatusCode::kDuplicateKey)) { | |||||
| MS_LOG(DEBUG) << "Ignoring duplicate key."; | |||||
| } else { | |||||
| RETURN_IF_NOT_OK(rc); | |||||
| } | |||||
| return Status::OK(); | |||||
| } catch (const std::exception &e) { | |||||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||||
| } | |||||
| } | |||||
| std::ostream &operator<<(std::ostream &out, const CacheService &cs) { | std::ostream &operator<<(std::ostream &out, const CacheService &cs) { | ||||
| // Then show any custom derived-internal stuff | // Then show any custom derived-internal stuff | ||||
| out << "\nCache memory size: " << cs.cache_mem_sz_; | out << "\nCache memory size: " << cs.cache_mem_sz_; | ||||
| @@ -155,20 +194,15 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) { | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| Status CacheService::PreBatchFetch(const std::vector<row_id_type> &v, std::vector<key_size_pair> *out, | |||||
| int64_t *mem_sz) { | |||||
| SharedLock rw(&rw_lock_); | SharedLock rw(&rw_lock_); | ||||
| if (st_ == State::kBuildPhase) { | |||||
| // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. | |||||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||||
| } | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| RETURN_UNEXPECTED_IF_NULL(mem_sz); | |||||
| const auto num_elements = v.size(); | const auto num_elements = v.size(); | ||||
| int64_t mem_sz = (num_elements + 1) * sizeof(int64_t); | |||||
| int64_t data_offset = mem_sz; | |||||
| std::vector<int64_t> sz_v; | |||||
| std::vector<CachePool::key_type> keys; | |||||
| sz_v.reserve(num_elements); | |||||
| keys.reserve(num_elements); | |||||
| *mem_sz = (num_elements + 1) * sizeof(int64_t); | |||||
| (*out).reserve(num_elements); | |||||
| for (auto row_id : v) { | for (auto row_id : v) { | ||||
| auto r = map_->Search(row_id); | auto r = map_->Search(row_id); | ||||
| if (r.second) { | if (r.second) { | ||||
| @@ -180,25 +214,33 @@ Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint | |||||
| errMsg += std::to_string(key); | errMsg += std::to_string(key); | ||||
| RETURN_STATUS_UNEXPECTED(errMsg); | RETURN_STATUS_UNEXPECTED(errMsg); | ||||
| } | } | ||||
| keys.push_back(key); | |||||
| sz_v.push_back(sz); | |||||
| mem_sz += sz; | |||||
| (*out).emplace_back(key, sz); | |||||
| (*mem_sz) += sz; | |||||
| } else { | } else { | ||||
| keys.push_back(-1); | |||||
| sz_v.push_back(0); | |||||
| (*out).emplace_back(-1, 0); | |||||
| } | } | ||||
| } | } | ||||
| MemGuard<uint8_t> mem; | |||||
| RETURN_IF_NOT_OK(mem.allocate(mem_sz)); | |||||
| auto *offset_array = reinterpret_cast<int64_t *>(mem.GetMutablePointer()); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheService::BatchFetch(const std::vector<row_id_type> &v, const std::vector<key_size_pair> &info, | |||||
| WritableSlice *out) const { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| SharedLock rw(&rw_lock_); | |||||
| if (st_ == State::kBuildPhase) { | |||||
| // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. | |||||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||||
| } | |||||
| const auto num_elements = v.size(); | |||||
| int64_t data_offset = (num_elements + 1) * sizeof(int64_t); | |||||
| auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer()); | |||||
| offset_array[0] = data_offset; | offset_array[0] = data_offset; | ||||
| WritableSlice all(mem.GetMutablePointer(), mem.GetSizeInBytes()); | |||||
| for (auto i = 0; i < num_elements; ++i) { | for (auto i = 0; i < num_elements; ++i) { | ||||
| auto sz = sz_v.at(i); | |||||
| auto sz = info.at(i).second; | |||||
| offset_array[i + 1] = offset_array[i] + sz; | offset_array[i + 1] = offset_array[i] + sz; | ||||
| if (sz > 0) { | if (sz > 0) { | ||||
| WritableSlice row_data(all, offset_array[i], sz); | |||||
| auto key = keys.at(i); | |||||
| WritableSlice row_data(*out, offset_array[i], sz); | |||||
| auto key = info.at(i).first; | |||||
| size_t bytesRead = 0; | size_t bytesRead = 0; | ||||
| RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead)); | RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead)); | ||||
| if (bytesRead != sz) { | if (bytesRead != sz) { | ||||
| @@ -208,7 +250,6 @@ Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| *out = std::move(mem); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
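The buffer BatchFetch fills is self-describing: the first (n+1) int64 values are byte offsets into the buffer itself, row i occupies [offset[i], offset[i+1]), and a zero-length row marks a cache miss. A sketch of how a reader would walk it:

```cpp
#include <cstddef>
#include <cstdint>

// Hedged sketch: walk the contiguous batch-fetch buffer on the reader side.
void WalkBatchBuffer(const uint8_t *buf, size_t num_elements) {
  auto *offset_array = reinterpret_cast<const int64_t *>(buf);
  for (size_t i = 0; i < num_elements; ++i) {
    int64_t sz = offset_array[i + 1] - offset_array[i];
    if (sz == 0) {
      continue;  // cache miss; the caller decides whether that is an error
    }
    const uint8_t *row_data = buf + offset_array[i];  // serialized tensor row
    // ... hand (row_data, sz) to the row deserializer ...
  }
}
```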
| Status CacheService::CacheSchema(const void *buf, int64_t len) { | Status CacheService::CacheSchema(const void *buf, int64_t len) { | ||||
| @@ -232,18 +273,26 @@ Status CacheService::CacheSchema(const void *buf, int64_t len) { | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheService::FetchSchema(MemGuard<uint8_t> *out) const { | |||||
| Status CacheService::FetchSchema(std::string *out) const { | |||||
| SharedLock rw(&rw_lock_); | SharedLock rw(&rw_lock_); | ||||
| if (st_ == State::kBuildPhase) { | if (st_ == State::kBuildPhase) { | ||||
| // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. | // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. | ||||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | ||||
| } | } | ||||
| RETURN_UNEXPECTED_IF_NULL(out); | RETURN_UNEXPECTED_IF_NULL(out); | ||||
| MemGuard<uint8_t> mem; | |||||
| // We are going to use std::string to allocate and hold the result, which will eventually be | |||||
| // 'moved' to the protobuf message (which underneath is also a std::string) in order | |||||
| // to minimize memory copying. | |||||
| std::string mem; | |||||
| if (schema_key_ >= 0) { | if (schema_key_ >= 0) { | ||||
| auto len = cp_->GetSize(schema_key_); | auto len = cp_->GetSize(schema_key_); | ||||
| RETURN_IF_NOT_OK(mem.allocate(len)); | |||||
| auto slice = WritableSlice(mem.GetMutablePointer(), len); | |||||
| try { | |||||
| mem.resize(len); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= len, "Programming error"); | |||||
| } catch (const std::bad_alloc &e) { | |||||
| return Status(StatusCode::kOutOfMemory); | |||||
| } | |||||
| auto slice = WritableSlice(mem.data(), len); | |||||
| RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice)); | RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice)); | ||||
| *out = std::move(mem); | *out = std::move(mem); | ||||
| } else { | } else { | ||||
| @@ -28,7 +28,6 @@ | |||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/cache/cache_request.h" | #include "minddata/dataset/engine/cache/cache_request.h" | ||||
| #include "minddata/dataset/engine/cache/de_tensor_generated.h" | |||||
| #include "minddata/dataset/util/arena.h" | #include "minddata/dataset/util/arena.h" | ||||
| #include "minddata/dataset/util/btree.h" | #include "minddata/dataset/util/btree.h" | ||||
| #include "minddata/dataset/util/cache_pool.h" | #include "minddata/dataset/util/cache_pool.h" | ||||
| @@ -38,7 +37,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| struct CacheStat; | |||||
| /// A typedef used by BatchFetch | |||||
| using key_size_pair = std::pair<CachePool::key_type, size_t>; | |||||
| /// \brief A cache service for storing/fetching buffers in an in-memory cache that may spill to disk if the | /// \brief A cache service for storing/fetching buffers in an in-memory cache that may spill to disk if the | ||||
| /// cache service is created to support spilling | /// cache service is created to support spilling | ||||
| class CacheService : public Service { | class CacheService : public Service { | ||||
| @@ -69,12 +69,26 @@ class CacheService : public Service { | |||||
| /// \param[out] row_id_generated The row id assigned to this row if any | /// \param[out] row_id_generated The row id assigned to this row if any | ||||
| /// \return Status object | /// \return Status object | ||||
| Status CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated); | Status CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated); | ||||
| /// \brief A fast version of CacheRow where all the data is already in one contiguous piece. | |||||
| /// \param src Slice of the data | |||||
| /// \param[out] row_id_generated The row id assigned to this row, if any | |||||
| /// \return Status object | |||||
| Status FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated); | |||||
| /// \brief This function is used in preparation for batch fetching. | |||||
| /// It calculates how much memory we should allocate and which row ids are present. | |||||
| /// \param[in] v A vector of row ids to fetch. | |||||
| /// \param[out] out Pointer to a vector of <CachePool::key_type, size_t> pairs, one per row id. | |||||
| /// \param[out] mem_sz How much memory is required for the batch fetch. | |||||
| /// \return Status object | |||||
| Status PreBatchFetch(const std::vector<row_id_type> &v, std::vector<key_size_pair> *, int64_t *mem_sz); | |||||
| /// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded | /// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded | ||||
| /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row. | /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row. | ||||
| /// \param[in] v A vector of row id. | /// \param[in] v A vector of row id. | ||||
| /// \param[out] out A contiguous memory buffer that holds the requested rows. | /// \param[out] out A contiguous memory buffer that holds the requested rows. | ||||
| /// \return Status object | /// \return Status object | ||||
| Status BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const; | |||||
| Status BatchFetch(const std::vector<row_id_type> &v, const std::vector<key_size_pair> &, WritableSlice *out) const; | |||||
| /// \brief Getter function | /// \brief Getter function | ||||
| /// \return Spilling path | /// \return Spilling path | ||||
| @@ -102,7 +116,7 @@ class CacheService : public Service { | |||||
| /// \brief Fetch schema | /// \brief Fetch schema | ||||
| /// \param out A contiguous memory that contains the serialized form of schema. | /// \param out A contiguous memory that contains the serialized form of schema. | ||||
| /// \return Status object | /// \return Status object | ||||
| Status FetchSchema(MemGuard<uint8_t> *out) const; | |||||
| Status FetchSchema(std::string *out) const; | |||||
| /// \brief Purge the content of a cache | /// \brief Purge the content of a cache | ||||
| /// \return Status object | /// \return Status object | ||||
| Status Purge(); | Status Purge(); | ||||
| @@ -60,10 +60,11 @@ table TensorRowIds { | |||||
| } | } | ||||
| /// Statistics returned from each cache service | /// Statistics returned from each cache service | ||||
| /// \note It must match CacheService::ServiceStat | |||||
| /// \note It must match CacheServiceStat | |||||
| table ServiceStatMsg { | table ServiceStatMsg { | ||||
| num_mem_cached:int64; | num_mem_cached:int64; | ||||
| num_disk_cached:int64; | num_disk_cached:int64; | ||||
| avg_cache_sz:int64; | |||||
| min_row_id:int64; | min_row_id:int64; | ||||
| max_row_id:int64; | max_row_id:int64; | ||||
| state:int8; | state:int8; | ||||
| @@ -79,3 +80,15 @@ table ColumnNameMsg { | |||||
| table SchemaMsg { | table SchemaMsg { | ||||
| column:[ColumnNameMsg]; | column:[ColumnNameMsg]; | ||||
| } | } | ||||
| /// Part of the CreateCacheRequest | |||||
| table CreateCacheRequestMsg { | |||||
| cache_mem_sz:int64; | |||||
| flag:uint32; | |||||
| } | |||||
| /// Return result of CreateCacheRequest | |||||
| table CreateCacheReplyMsg { | |||||
| connection_id:int64; | |||||
| cookie:string; | |||||
| } | |||||
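These tables are serialized with the usual flatbuffers builder pattern, mirroring the ServiceStatMsgBuilder usage earlier in this patch. A hedged sketch of how a client might fill in CreateCacheRequestMsg (CreateCacheRequestMsgBuilder and its add_* methods follow flatbuffers codegen conventions; the client-side call site is not part of this patch):

```cpp
flatbuffers::FlatBufferBuilder fbb;
CreateCacheRequestMsgBuilder bld(fbb);
bld.add_cache_mem_sz(cache_mem_sz);                 // bytes of memory the cache may use
bld.add_flag(static_cast<uint32_t>(create_flags));  // CreateCacheFlag bitmask
fbb.Finish(bld.Finish());
// Ship the serialized table as the first buf_data element of the request.
rq.add_buf_data(reinterpret_cast<const char *>(fbb.GetBufferPointer()), fbb.GetSize());
```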
| @@ -0,0 +1,45 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "proto/cache_grpc.pb.h" | |||||
| #include "minddata/dataset/engine/cache/cache_common.h" | |||||
| #include "minddata/dataset/engine/cache/cache_request.h" | |||||
| #include "minddata/dataset/util/service.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| class CacheClientGreeter : public Service { | |||||
| public: | |||||
| explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers) {} | |||||
| ~CacheClientGreeter() override {} | |||||
| Status DoServiceStart() override { RETURN_STATUS_UNEXPECTED("Not supported"); } | |||||
| Status DoServiceStop() override { RETURN_STATUS_UNEXPECTED("Not supported"); } | |||||
| void *SharedMemoryBaseAddr() { return nullptr; } | |||||
| Status HandleRequest(std::shared_ptr<BaseRequest> rq) { RETURN_STATUS_UNEXPECTED("Not supported"); } | |||||
| Status AttachToSharedMemory(int32_t port, bool *local_bypass) { RETURN_STATUS_UNEXPECTED("Not supported"); } | |||||
| protected: | |||||
| private: | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_ | |||||
| @@ -16,6 +16,7 @@ | |||||
| #include "minddata/dataset/engine/datasetops/cache_base_op.h" | #include "minddata/dataset/engine/datasetops/cache_base_op.h" | ||||
| #include <iomanip> | #include <iomanip> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <utility> | |||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -47,22 +48,39 @@ Status CacheBase::Reset() { | |||||
| } | } | ||||
| CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | ||||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler) | std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler) | ||||
| : ParallelOp(num_workers, op_connector_size, sampler), | |||||
| cache_client_(cache_client), | |||||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | |||||
| row_cnt_(0), | |||||
| num_cache_miss_(0), | |||||
| cache_client_(std::move(cache_client)), | |||||
| rows_per_buffer_(rows_per_buf), | rows_per_buffer_(rows_per_buf), | ||||
| // We can cause deadlock if this internal Connector size is too small. | // We can cause deadlock if this internal Connector size is too small. | ||||
| keys_miss_(num_workers_, 1, connector_capacity_) { | |||||
| keys_miss_(num_workers_, 1, connector_capacity_), | |||||
| prefetch_size_(cache_client_->getPrefetchSize()) { | |||||
| io_block_queues_.Init(num_workers, op_connector_size); | io_block_queues_.Init(num_workers, op_connector_size); | ||||
| prefetch_queues_.Init(num_workers, op_connector_size); | |||||
| sampler_queue_ = std::make_unique<Queue<std::shared_ptr<Tensor>>>(op_connector_size); | |||||
| } | } | ||||
| // Common function to fetch samples from the sampler and send them using the io_block_queues to | // Common function to fetch samples from the sampler and send them using the io_block_queues to | ||||
| // the parallel workers | // the parallel workers | ||||
| Status CacheBase::FetchSamplesToWorkers() { | Status CacheBase::FetchSamplesToWorkers() { | ||||
| int64_t buf_cnt = 0; | int64_t buf_cnt = 0; | ||||
| int64_t wait_cnt = 0; | int64_t wait_cnt = 0; | ||||
| // Kick off several threads which will prefetch prefetch_size_ rows in advance. rows_per_buffer_ | |||||
| // is too small (1 by default) to help performance on its own. | |||||
| RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Dispatcher", std::bind(&CacheBase::Dispatcher, this))); | |||||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1))); | |||||
| // Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them | |||||
| // to the WorkerEntry. | |||||
| do { | do { | ||||
| epoch_sync_.Clear(); | epoch_sync_.Clear(); | ||||
| if (AllowCacheMiss() && wait_cnt > 0) { | |||||
| MS_LOG(WARNING) << "Epoch: " << wait_cnt << " Cache Miss : " << num_cache_miss_ | |||||
| << " Total number of rows : " << row_cnt_; | |||||
| } | |||||
| num_cache_miss_ = 0; | |||||
| row_cnt_ = 0; | |||||
| ++wait_cnt; | |||||
| std::vector<row_id_type> keys; | std::vector<row_id_type> keys; | ||||
| int64_t row_cnt = 0; | |||||
| keys.reserve(rows_per_buffer_); | keys.reserve(rows_per_buffer_); | ||||
| std::unique_ptr<DataBuffer> sampler_buffer; | std::unique_ptr<DataBuffer> sampler_buffer; | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | ||||
| @@ -70,10 +88,13 @@ Status CacheBase::FetchSamplesToWorkers() { | |||||
| TensorRow sample_row; | TensorRow sample_row; | ||||
| RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); | RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); | ||||
| std::shared_ptr<Tensor> sample_ids = sample_row[0]; | std::shared_ptr<Tensor> sample_ids = sample_row[0]; | ||||
| // Send the sampler tensor to another thread for prefetching. We are using a shared pointer so it | |||||
| // won't go out of scope until it is really no longer in use. | |||||
| RETURN_IF_NOT_OK(sampler_queue_->Add(sample_ids)); | |||||
| for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) { | for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) { | ||||
| keys.push_back(*itr); | keys.push_back(*itr); | ||||
| ++row_cnt; | |||||
| if (row_cnt % rows_per_buffer_ == 0) { | |||||
| ++row_cnt_; | |||||
| if (row_cnt_ % rows_per_buffer_ == 0) { | |||||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | ||||
| RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); | RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); | ||||
| keys.clear(); | keys.clear(); | ||||
| @@ -90,7 +111,7 @@ Status CacheBase::FetchSamplesToWorkers() { | |||||
| io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | ||||
| // If repeat but the not last repeat, wait for reset. | // If repeat but the not last repeat, wait for reset. | ||||
| if (!IsLastIteration()) { | if (!IsLastIteration()) { | ||||
| MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt; | |||||
| MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt; | |||||
| RETURN_IF_NOT_OK(epoch_sync_.Wait()); | RETURN_IF_NOT_OK(epoch_sync_.Wait()); | ||||
| } else { | } else { | ||||
| // We can break out from the loop. | // We can break out from the loop. | ||||
| @@ -101,13 +122,21 @@ Status CacheBase::FetchSamplesToWorkers() { | |||||
| // Flow the eof before exit | // Flow the eof before exit | ||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof))); | io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof))); | ||||
| // Ask all the workers to quit. | |||||
| // Shutdown threads | |||||
| std::shared_ptr<Tensor> empty; | |||||
| RETURN_IF_NOT_OK(sampler_queue_->Add(std::move(empty))); | |||||
| for (int32_t i = 0; i < num_workers_; i++) { | for (int32_t i = 0; i < num_workers_; i++) { | ||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | ||||
| } | } | ||||
| // Dump the last epoch result (approximately) without waiting for the worker threads to come back. | |||||
| if (AllowCacheMiss()) { | |||||
| MS_LOG(WARNING) << "Epoch: " << wait_cnt << " Cache Miss : " << num_cache_miss_ | |||||
| << " Total number of rows : " << row_cnt_; | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheBase::FetchFromCache(int32_t worker_id) { | Status CacheBase::FetchFromCache(int32_t worker_id) { | ||||
| int64_t buffer_id = worker_id; | int64_t buffer_id = worker_id; | ||||
| std::unique_ptr<IOBlock> blk; | std::unique_ptr<IOBlock> blk; | ||||
| @@ -133,23 +162,16 @@ Status CacheBase::FetchFromCache(int32_t worker_id) { | |||||
| } | } | ||||
| std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone); | std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone); | ||||
| std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>(); | std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>(); | ||||
| TensorTable ttbl; | |||||
| RETURN_IF_NOT_OK(cache_client_->GetRows(keys, &ttbl)); | |||||
| auto row_it = ttbl.begin(); | |||||
| std::vector<row_id_type> cache_miss; | std::vector<row_id_type> cache_miss; | ||||
| cache_miss.reserve(keys.size()); | cache_miss.reserve(keys.size()); | ||||
| for (auto row_id : keys) { | for (auto row_id : keys) { | ||||
| auto &row = *row_it; | |||||
| TensorRow row; | |||||
| // Block until the row shows up in the pool. | |||||
| RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, &row)); | |||||
| if (row.empty()) { | if (row.empty()) { | ||||
| if (AllowCacheMiss()) { | |||||
| cache_miss.push_back(row_id); | |||||
| } else { | |||||
| std::string errMsg = "Row id " + std::to_string(row_id) + " not found."; | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| cache_miss.push_back(row_id); | |||||
| } | } | ||||
| que->push_back(std::move(row)); | que->push_back(std::move(row)); | ||||
| ++row_it; | |||||
| } | } | ||||
| db->set_tensor_table(std::move(que)); | db->set_tensor_table(std::move(que)); | ||||
| if (AllowCacheMiss()) { | if (AllowCacheMiss()) { | ||||
| @@ -162,12 +184,17 @@ Status CacheBase::FetchFromCache(int32_t worker_id) { | |||||
| } while (true); | } while (true); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheBase::RegisterResources() { | Status CacheBase::RegisterResources() { | ||||
| RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks())); | RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks())); | ||||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | ||||
| RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks())); | |||||
| RETURN_IF_NOT_OK(sampler_queue_->Register(tree_->AllTasks())); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| CacheBase::~CacheBase() {} | |||||
| CacheBase::~CacheBase() = default; | |||||
| Status CacheBase::UpdateColumnMapFromCache() { | Status CacheBase::UpdateColumnMapFromCache() { | ||||
| Status rc; | Status rc; | ||||
| // Get the schema from the server. It may not be there yet. So tolerate the error. | // Get the schema from the server. It may not be there yet. So tolerate the error. | ||||
| @@ -180,5 +207,77 @@ Status CacheBase::UpdateColumnMapFromCache() { | |||||
| } | } | ||||
| return rc; | return rc; | ||||
| } | } | ||||
| Status CacheBase::Dispatcher() { | |||||
| TaskManager::FindMe()->Post(); | |||||
| int64_t buf_cnt = 0; | |||||
| int64_t num_row = 0; | |||||
| std::vector<row_id_type> keys; | |||||
| keys.reserve(prefetch_size_); | |||||
| do { | |||||
| keys.clear(); | |||||
| std::shared_ptr<Tensor> sample_ids; | |||||
| RETURN_IF_NOT_OK(sampler_queue_->PopFront(&sample_ids)); | |||||
| if (sample_ids == nullptr) { | |||||
| // A null shared pointer signal times to quit. | |||||
| // Also signal all prefetchers to quit. | |||||
| for (int32_t i = 0; i < num_workers_; i++) { | |||||
| RETURN_IF_NOT_OK( | |||||
| prefetch_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||||
| } | |||||
| break; | |||||
| } | |||||
| // Now we distribute the sampler ids to each prefetcher according to the prefetch size. | |||||
| for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) { | |||||
| keys.push_back(*itr); | |||||
| ++num_row; | |||||
| if (num_row % prefetch_size_ == 0) { | |||||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||||
| RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); | |||||
| keys.clear(); | |||||
| } | |||||
| } | |||||
| // Send the remaining sample id | |||||
| if (!keys.empty()) { | |||||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||||
| RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); | |||||
| } | |||||
| } while (true); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheBase::Prefetcher(int32_t worker_id) { | |||||
| TaskManager::FindMe()->Post(); | |||||
| std::vector<row_id_type> prefetch_keys; | |||||
| prefetch_keys.reserve(prefetch_size_); | |||||
| do { | |||||
| prefetch_keys.clear(); | |||||
| std::unique_ptr<IOBlock> blk; | |||||
| RETURN_IF_NOT_OK(prefetch_queues_[worker_id]->PopFront(&blk)); | |||||
| RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys)); | |||||
| if (prefetch_keys.empty()) { | |||||
| // Empty keys mean time to quit. | |||||
| break; | |||||
| } | |||||
| TensorTable ttbl; | |||||
| RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl)); | |||||
| auto row_it = ttbl.begin(); | |||||
| for (auto row_id : prefetch_keys) { | |||||
| auto &row = *row_it; | |||||
| if (row.empty()) { | |||||
| if (AllowCacheMiss()) { | |||||
| ++num_cache_miss_; | |||||
| } else { | |||||
| std::string errMsg = "Row id " + std::to_string(row_id) + " not found."; | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| } | |||||
| // Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row | |||||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||||
| ++row_it; | |||||
| } | |||||
| } while (true); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ | #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ | ||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ | #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ | ||||
| #include <atomic> | |||||
| #include <deque> | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <utility> | #include <utility> | ||||
| @@ -28,8 +30,9 @@ | |||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | ||||
| #include "minddata/dataset/util/queue.h" | #include "minddata/dataset/util/queue.h" | ||||
| #include "minddata/dataset/util/queue_map.h" | |||||
| #include "minddata/dataset/util/semaphore.h" | |||||
| #include "minddata/dataset/util/wait_post.h" | #include "minddata/dataset/util/wait_post.h" | ||||
| #include "minddata/dataset/engine/datasetops/cache_base_op.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| /// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities. | /// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities. | ||||
| @@ -82,10 +85,13 @@ class CacheBase : public ParallelOp { | |||||
| protected: | protected: | ||||
| constexpr static int32_t eoe_row_id = -1; | constexpr static int32_t eoe_row_id = -1; | ||||
| int64_t row_cnt_; | |||||
| std::atomic<int64_t> num_cache_miss_; | |||||
| std::shared_ptr<CacheClient> cache_client_; | std::shared_ptr<CacheClient> cache_client_; | ||||
| WaitPost epoch_sync_; | WaitPost epoch_sync_; | ||||
| int32_t rows_per_buffer_; | int32_t rows_per_buffer_; | ||||
| Connector<std::vector<row_id_type>> keys_miss_; | Connector<std::vector<row_id_type>> keys_miss_; | ||||
| QueueMap<row_id_type, TensorRow> prefetch_; | |||||
| /// \brief Common function to register resources for interrupt | /// \brief Common function to register resources for interrupt | ||||
| /// \note Derived should override this function for extra resources to be registered | /// \note Derived should override this function for extra resources to be registered | ||||
| @@ -103,7 +109,15 @@ class CacheBase : public ParallelOp { | |||||
| private: | private: | ||||
| constexpr static int32_t connector_capacity_ = 1024; | constexpr static int32_t connector_capacity_ = 1024; | ||||
| int32_t prefetch_size_; | |||||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | ||||
| QueueList<std::unique_ptr<IOBlock>> prefetch_queues_; | |||||
| std::unique_ptr<Queue<std::shared_ptr<Tensor>>> sampler_queue_; | |||||
| Status Dispatcher(); | |||||
| /// \brief Prefetcher. It prefetch the rows from cache server | |||||
| /// \return Status object. | |||||
| Status Prefetcher(int32_t worker_id); | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,8 +16,10 @@ | |||||
| #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <chrono> | |||||
| #include <functional> | #include <functional> | ||||
| #include <iomanip> | #include <iomanip> | ||||
| #include <utility> | |||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "minddata/dataset/core/constants.h" | #include "minddata/dataset/core/constants.h" | ||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| @@ -41,9 +43,13 @@ void CacheMergeOp::Print(std::ostream &out, bool show_all) const { | |||||
| out << "\n\n"; | out << "\n\n"; | ||||
| } | } | ||||
| } | } | ||||
| CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, | CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, | ||||
| std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler) | std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler) | ||||
| : ParallelOp(numWorkers, opConnectorSize, sampler), num_cleaners_(numCleaners), cache_client_(cache_client) {} | |||||
| : ParallelOp(numWorkers, opConnectorSize, sampler), | |||||
| num_cleaners_(numCleaners), | |||||
| cache_client_(std::move(cache_client)) {} | |||||
| Status CacheMergeOp::operator()() { | Status CacheMergeOp::operator()() { | ||||
| // A queue of row id to let cleaner send cache miss rows to the cache server | // A queue of row id to let cleaner send cache miss rows to the cache server | ||||
| // We don't want a small queue as this will block the parallel op workers. | // We don't want a small queue as this will block the parallel op workers. | ||||
| @@ -62,6 +68,7 @@ Status CacheMergeOp::operator()() { | |||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Each parallel worker will pop from the CacheHit stream. If there is a missing TensorRow, we will wait | // Each parallel worker will pop from the CacheHit stream. If there is a missing TensorRow, we will wait | ||||
| // until it shows up in the pool. | // until it shows up in the pool. | ||||
| Status CacheMergeOp::WorkerEntry(int32_t worker_id) { | Status CacheMergeOp::WorkerEntry(int32_t worker_id) { | ||||
| @@ -82,10 +89,8 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) { | |||||
| RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); | RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); | ||||
| if (row.empty()) { | if (row.empty()) { | ||||
| auto row_id = row.getId(); | auto row_id = row.getId(); | ||||
| TensorRowRequest *rq = nullptr; | |||||
| RETURN_IF_NOT_OK(GetRq(row_id, &rq)); | |||||
| // Block until the row shows up in the pool. | // Block until the row shows up in the pool. | ||||
| RETURN_IF_NOT_OK(rq->Wait(&row)); | |||||
| RETURN_IF_NOT_OK(cache_miss_.PopFront(row_id, &row)); | |||||
| } | } | ||||
| tbl->push_back(std::move(row)); | tbl->push_back(std::move(row)); | ||||
| } | } | ||||
| @@ -97,6 +102,7 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) { | |||||
| RETURN_IF_NOT_OK(EofReceived(worker_id)); | RETURN_IF_NOT_OK(EofReceived(worker_id)); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { | Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { | ||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| // We will simply pop TensorRow from the stream and insert them into the pool and | // We will simply pop TensorRow from the stream and insert them into the pool and | ||||
| @@ -123,17 +129,27 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { | |||||
| std::string errMsg = "Expect positive row id: " + std::to_string(row_id); | std::string errMsg = "Expect positive row id: " + std::to_string(row_id); | ||||
| RETURN_STATUS_UNEXPECTED(errMsg); | RETURN_STATUS_UNEXPECTED(errMsg); | ||||
| } | } | ||||
| TensorRowRequest *rq = nullptr; | |||||
| // Technically number of this row shows up in the cache miss stream is equal to the number | |||||
| // of P() call. However the cleaner wants it too. So we need an extra copy. | |||||
| TensorRowCacheRequest *rq; | |||||
| RETURN_IF_NOT_OK(GetRq(row_id, &rq)); | RETURN_IF_NOT_OK(GetRq(row_id, &rq)); | ||||
| rq->WakeUpAny(std::move(row)); | |||||
| // Let the cleaner to flush out this row (async) to the cache server. | |||||
| RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); | |||||
| if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) { | |||||
| // We will send the request async. But any error we most | |||||
| // likely ignore and continue. | |||||
| Status rc; | |||||
| rc = rq->AsyncSendCacheRequest(cache_client_, row); | |||||
| if (rc.IsOk()) { | |||||
| RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); | |||||
| } | |||||
| } | |||||
| RETURN_IF_NOT_OK(cache_miss_.Add(row_id, std::move(row))); | |||||
| } | } | ||||
| } | } | ||||
| RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId)); | RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId)); | ||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheMergeOp::Cleaner() { | Status CacheMergeOp::Cleaner() { | ||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| while (true) { | while (true) { | ||||
| @@ -142,45 +158,28 @@ Status CacheMergeOp::Cleaner() { | |||||
| if (row_id < 0) { | if (row_id < 0) { | ||||
| break; | break; | ||||
| } | } | ||||
| TensorRowRequest *rq = nullptr; | |||||
| // Locate the cache request | |||||
| TensorRowCacheRequest *rq; | |||||
| RETURN_IF_NOT_OK(GetRq(row_id, &rq)); | RETURN_IF_NOT_OK(GetRq(row_id, &rq)); | ||||
| if (rq->GetState() == TensorRowRequest::State::kClean) { | |||||
| // If already flushed, move on to the next one. | |||||
| // If already flushed, move on to the next one. | |||||
| if (rq->GetState() == TensorRowCacheRequest::State::kClean) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| TensorRow row; | |||||
| RETURN_IF_NOT_OK(rq->Release(&row)); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!row.empty(), "Programming error."); | |||||
| Status rc = cache_client_->WriteRow(row); | |||||
| // Bad rc should not bring down the pipeline | |||||
| Status rc = rq->CheckCacheResult(); | |||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| MS_LOG(WARNING) << "Cache not successful." << rc.ToString(); | |||||
| // If interrupt, time to quit. | |||||
| if (rc.get_code() == StatusCode::kInterrupted) { | |||||
| return Status::OK(); | |||||
| } | |||||
| MS_LOG(INFO) << "Cache row not successful: " << rc.ToString(); | |||||
| // Bad rc should not bring down the pipeline. We will simply continue and | |||||
| // change the state back to empty. We don't need a CAS from CLEAN back to EMPTY. | |||||
| rq->SetState(TensorRowCacheRequest::State::kEmpty); | |||||
| } | } | ||||
| rq->SetState(TensorRowRequest::State::kClean); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowRequest **out) { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| std::unique_lock<std::mutex> lck(mux_); | |||||
| auto it = cache_miss_map_.find(row_id); | |||||
| if (it != cache_miss_map_.end()) { | |||||
| *out = it->second.GetMutablePointer(); | |||||
| } else { | |||||
| // We will create a new one. | |||||
| auto alloc = Services::GetAllocator<TensorRowRequest>(); | |||||
| auto r = cache_miss_map_.emplace(row_id, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>(alloc)); | |||||
| if (r.second) { | |||||
| auto &mem = r.first->second; | |||||
| RETURN_IF_NOT_OK(mem.allocate(1, row_id)); | |||||
| *out = mem.GetMutablePointer(); | |||||
| } else { | |||||
| RETURN_STATUS_UNEXPECTED("Map insert fail."); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from super class first before adding our own | Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from super class first before adding our own | ||||
| // specific logic | // specific logic | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children"); | CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children"); | ||||
| @@ -199,6 +198,7 @@ Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from supe | |||||
| RETURN_IF_NOT_OK(rc); | RETURN_IF_NOT_OK(rc); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheMergeOp::ComputeColMap() { | Status CacheMergeOp::ComputeColMap() { | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(child_[kCacheMissChildIdx] != nullptr, "Cache miss stream empty"); | CHECK_FAIL_RETURN_UNEXPECTED(child_[kCacheMissChildIdx] != nullptr, "Cache miss stream empty"); | ||||
| if (column_name_id_map().empty()) { | if (column_name_id_map().empty()) { | ||||
| @@ -207,53 +207,13 @@ Status CacheMergeOp::ComputeColMap() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map().empty(), "No column map detected"); | CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map().empty(), "No column map detected"); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheMergeOp::TensorRowRequest::Wait(TensorRow *out) { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| // Block until the missing row is in the pool. | |||||
| RETURN_IF_NOT_OK(use_count_.P()); | |||||
| std::unique_lock<std::mutex> lck(dq_mux_); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error"); | |||||
| *out = std::move(row_.front()); | |||||
| row_.pop_front(); | |||||
| return Status::OK(); | |||||
| } | |||||
| void CacheMergeOp::TensorRowRequest::WakeUpAny(TensorRow &&row) { | |||||
| std::unique_lock<std::mutex> lck(dq_mux_); | |||||
| // Technically number of this row shows up in the cache miss stream is equal to the number | |||||
| // of P() call. However the cleaner wants it too. So we need an extra copy. | |||||
| if (GetState() == State::kEmpty) { | |||||
| // We will do a deep copy | |||||
| for (auto &ts : row) { | |||||
| std::shared_ptr<Tensor> out_ts; | |||||
| Tensor::CreateFromTensor(ts, &out_ts); | |||||
| cleaner_copy_.push_back(out_ts); | |||||
| } | |||||
| cleaner_copy_.setId(row.getId()); | |||||
| // Change the state to dirty | |||||
| SetState(State::kDirty); | |||||
| } | |||||
| row_.push_back(std::move(row)); | |||||
| // Bump up the use count by 1. This wake up any parallel worker which is waiting | |||||
| // for this row. | |||||
| use_count_.V(); | |||||
| } | |||||
| Status CacheMergeOp::TensorRowRequest::Release(TensorRow *out) { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| // We are not holding any mutex here because the cleaner isn't really touching the deque row_. | |||||
| // In case we have multiple cleaners and they all see the copy, only one of them will | |||||
| // get it. | |||||
| auto expected = State::kDirty; | |||||
| if (st_.compare_exchange_strong(expected, State::kClean)) { | |||||
| *out = std::move(cleaner_copy_); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Builder constructor. Creates the builder object. | // Builder constructor. Creates the builder object. | ||||
| CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { | CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { | ||||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | ||||
| build_num_workers_ = cfg->num_parallel_workers(); | build_num_workers_ = cfg->num_parallel_workers(); | ||||
| build_op_connector_size_ = cfg->op_connector_size(); | build_op_connector_size_ = cfg->op_connector_size(); | ||||
| build_num_cleaners_ = 1; | |||||
| build_num_cleaners_ = cfg->num_parallel_workers(); | |||||
| } | } | ||||
| // Check if the required parameters are set by the builder. | // Check if the required parameters are set by the builder. | ||||
| @@ -311,5 +271,60 @@ Status CacheMergeOp::EofReceived(int32_t worker_id) { | |||||
| MS_LOG(DEBUG) << "Cache merge sending eof"; | MS_LOG(DEBUG) << "Cache merge sending eof"; | ||||
| return DatasetOp::EofReceived(worker_id); | return DatasetOp::EofReceived(worker_id); | ||||
| } | } | ||||
| Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowCacheRequest **out) { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| std::unique_lock<std::mutex> lock(mux_); | |||||
| auto it = io_request_.find(row_id); | |||||
| if (it != io_request_.end()) { | |||||
| *out = it->second.GetMutablePointer(); | |||||
| } else { | |||||
| // We will create a new one. | |||||
| auto alloc = Services::GetAllocator<TensorRowCacheRequest>(); | |||||
| auto r = io_request_.emplace(row_id, MemGuard<TensorRowCacheRequest, Allocator<TensorRowCacheRequest>>(alloc)); | |||||
| if (r.second) { | |||||
| auto &mem = r.first->second; | |||||
| RETURN_IF_NOT_OK(mem.allocate(1)); | |||||
| *out = mem.GetMutablePointer(); | |||||
| } else { | |||||
| RETURN_STATUS_UNEXPECTED("Map insert fail."); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::shared_ptr<CacheClient> &cc, | |||||
| const TensorRow &row) { | |||||
| auto expected = State::kEmpty; | |||||
| if (st_.compare_exchange_strong(expected, State::kDirty)) { | |||||
| // We will do a deep copy but write directly into CacheRequest protobuf or shared memory | |||||
| Status rc; | |||||
| cleaner_copy_ = | |||||
| std::make_shared<CacheRowRequest>(cc->server_connection_id_, cc->cookie(), cc->SupportLocalClient()); | |||||
| rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row); | |||||
| if (rc.IsOk()) { | |||||
| // Send the request async. The cleaner will check the return code. | |||||
| rc = cc->PushRequest(cleaner_copy_); | |||||
| } | |||||
| if (rc.IsError()) { | |||||
| // Clean up the shared pointer and reset the state back to empty | |||||
| cleaner_copy_.reset(); | |||||
| st_ = State::kEmpty; | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheMergeOp::TensorRowCacheRequest::CheckCacheResult() { | |||||
| auto expected = State::kDirty; | |||||
| if (st_.compare_exchange_strong(expected, State::kClean)) { | |||||
| // Success or not, we will release the memory. | |||||
| // We simply move it out of the structure and let it go out of scope. | |||||
| auto cache_request = std::move(cleaner_copy_); | |||||
| RETURN_IF_NOT_OK(cache_request->Wait()); | |||||
| return Status::OK(); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,7 @@ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ | #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ | ||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ | #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ | ||||
| #include <algorithm> | |||||
| #include <atomic> | #include <atomic> | ||||
| #include <deque> | #include <deque> | ||||
| #include <map> | #include <map> | ||||
| @@ -28,6 +29,7 @@ | |||||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | #include "minddata/dataset/engine/datasetops/parallel_op.h" | ||||
| #include "minddata/dataset/engine/dataset_iterator.h" | #include "minddata/dataset/engine/dataset_iterator.h" | ||||
| #include "minddata/dataset/util/queue.h" | #include "minddata/dataset/util/queue.h" | ||||
| #include "minddata/dataset/util/queue_map.h" | |||||
| #include "minddata/dataset/util/semaphore.h" | #include "minddata/dataset/util/semaphore.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -36,28 +38,34 @@ namespace dataset { | |||||
| /// stream | /// stream | ||||
| class CacheMergeOp : public ParallelOp { | class CacheMergeOp : public ParallelOp { | ||||
| public: | public: | ||||
| // Some handshake structures among the main thread, cleaner threads and parallel op threads. | |||||
| class TensorRowRequest { | |||||
| // Some handshake structures between CacheMissWorkerEntry and Cleaner | |||||
| class TensorRowCacheRequest { | |||||
| public: | public: | ||||
| enum class State : uint8_t { | enum class State : uint8_t { | ||||
| kEmpty = 0, // No row in the deque | |||||
| kEmpty = 0, // Initial state. Row hasn't arrived from cache miss stream yet. | |||||
| kDirty = 1, // Cleaner hasn't flushed it to the cache server yet. | kDirty = 1, // Cleaner hasn't flushed it to the cache server yet. | ||||
| kClean = 2 // The row has been flushed already. | kClean = 2 // The row has been flushed already. | ||||
| }; | }; | ||||
| explicit TensorRowRequest(row_id_type id) : st_(State::kEmpty), use_count_(0) {} | |||||
| ~TensorRowRequest() = default; | |||||
| TensorRowCacheRequest() : st_(State::kEmpty) {} | |||||
| ~TensorRowCacheRequest() = default; | |||||
| /// Getter and Setter of the state | |||||
| State GetState() const { return st_; } | State GetState() const { return st_; } | ||||
| void SetState(State newState) { st_ = newState; } | void SetState(State newState) { st_ = newState; } | ||||
| Status Wait(TensorRow *out); | |||||
| void WakeUpAny(TensorRow &&row); | |||||
| Status Release(TensorRow *out); | |||||
| /// Take a tensor row and send rpc call to the server async | |||||
| /// \param cc Cache client of the CacheMergeOp | |||||
| /// \param row TensorRow to be sent to the server | |||||
| /// \return Status object | |||||
| /// \note Thread safe | |||||
| Status AsyncSendCacheRequest(const std::shared_ptr<CacheClient> &cc, const TensorRow &row); | |||||
| /// \brief We send the row to the server async so the CacheMissWorkerEntry can continue. | |||||
| /// It is the cleaner that will check the result. | |||||
| /// \return Status object | |||||
| Status CheckCacheResult(); | |||||
| private: | private: | ||||
| std::mutex dq_mux_; | |||||
| std::atomic<State> st_; | std::atomic<State> st_; | ||||
| Semaphore use_count_; | |||||
| std::deque<TensorRow> row_; | |||||
| TensorRow cleaner_copy_; | |||||
| std::shared_ptr<CacheRowRequest> cleaner_copy_; | |||||
| }; | }; | ||||
| constexpr static int kCacheHitChildIdx = 0; // Cache hit stream | constexpr static int kCacheHitChildIdx = 0; // Cache hit stream | ||||
| @@ -80,6 +88,8 @@ class CacheMergeOp : public ParallelOp { | |||||
| /// \return Builder setter method returns reference to the builder. | /// \return Builder setter method returns reference to the builder. | ||||
| Builder &SetNumWorkers(int32_t num_workers) { | Builder &SetNumWorkers(int32_t num_workers) { | ||||
| build_num_workers_ = num_workers; | build_num_workers_ = num_workers; | ||||
| // Adjust the number of cleaners to match the number of workers | |||||
| build_num_cleaners_ = std::max(build_num_cleaners_, build_num_workers_); | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -159,7 +169,6 @@ class CacheMergeOp : public ParallelOp { | |||||
| /// \param workerId | /// \param workerId | ||||
| /// \return Status object | /// \return Status object | ||||
| Status CacheMissWorkerEntry(int32_t workerId); | Status CacheMissWorkerEntry(int32_t workerId); | ||||
| Status GetRq(row_id_type row_id, TensorRowRequest **); | |||||
| /// \brief Base-class override for NodePass pre-visit acceptor | /// \brief Base-class override for NodePass pre-visit acceptor | ||||
| /// \param[in] p The node to visit | /// \param[in] p The node to visit | ||||
| @@ -188,11 +197,18 @@ class CacheMergeOp : public ParallelOp { | |||||
| private: | private: | ||||
| std::mutex mux_; | std::mutex mux_; | ||||
| std::map<row_id_type, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>> cache_miss_map_; | |||||
| QueueMap<row_id_type, TensorRow> cache_miss_; | |||||
| std::map<row_id_type, MemGuard<TensorRowCacheRequest, Allocator<TensorRowCacheRequest>>> io_request_; | |||||
| std::unique_ptr<Queue<row_id_type>> io_que_; | std::unique_ptr<Queue<row_id_type>> io_que_; | ||||
| std::shared_ptr<CacheClient> cache_client_; | std::shared_ptr<CacheClient> cache_client_; | ||||
| int32_t num_cleaners_; | int32_t num_cleaners_; | ||||
| /// \brief Locate the cache request from the io_request_ map | |||||
| /// \param row_id | |||||
| /// \param out pointer to the cache request | |||||
| /// \return Status object | |||||
| Status GetRq(row_id_type row_id, TensorRowCacheRequest **out); | |||||
| /// \brief These are the entry functions for the cleaner threads. Each cleaner is responsible for | /// \brief These are the entry functions for the cleaner threads. Each cleaner is responsible for | ||||
| /// moving cache miss TensorRow into the CacheServer. | /// moving cache miss TensorRow into the CacheServer. | ||||
| /// \return Status object | /// \return Status object | ||||
| @@ -142,7 +142,7 @@ Status CacheOp::WaitForCachingAllRows() { | |||||
| } | } | ||||
| // Get statistics from the server, and if we are not the one to create the cache, | // Get statistics from the server, and if we are not the one to create the cache, | ||||
| // wait until the state changed from build phase to fetch base. | // wait until the state changed from build phase to fetch base. | ||||
| CacheClient::ServiceStat stat{}; | |||||
| CacheServiceStat stat{}; | |||||
| bool BuildPhaseDone = true; | bool BuildPhaseDone = true; | ||||
| do { | do { | ||||
| RETURN_IF_NOT_OK(cache_client_->GetStat(&stat)); | RETURN_IF_NOT_OK(cache_client_->GetStat(&stat)); | ||||
| @@ -157,6 +157,7 @@ Status CacheOp::WaitForCachingAllRows() { | |||||
| MS_LOG(INFO) << "Number of rows cached: " << num_rows_; | MS_LOG(INFO) << "Number of rows cached: " << num_rows_; | ||||
| MS_LOG(INFO) << "Number of rows cached in memory : " << stat.num_mem_cached; | MS_LOG(INFO) << "Number of rows cached in memory : " << stat.num_mem_cached; | ||||
| MS_LOG(INFO) << "Number of rows spilled to disk : " << stat.num_disk_cached; | MS_LOG(INFO) << "Number of rows spilled to disk : " << stat.num_disk_cached; | ||||
| MS_LOG(INFO) << "Average cache size : " << stat.avg_cache_sz; | |||||
| // Now all rows are cached and we have done a sync point check up. Next phase is | // Now all rows are cached and we have done a sync point check up. Next phase is | ||||
| // is pick up fetch input from sampler and pass up to the caller. | // is pick up fetch input from sampler and pass up to the caller. | ||||
| RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); | RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); | ||||
| @@ -392,6 +392,13 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) { | |||||
| ss_str = std::regex_replace(ss_str, std::regex("Num workers.*\n"), ""); | ss_str = std::regex_replace(ss_str, std::regex("Num workers.*\n"), ""); | ||||
| ss_str = std::regex_replace(ss_str, std::regex("\\[workers.*\\]"), ""); | ss_str = std::regex_replace(ss_str, std::regex("\\[workers.*\\]"), ""); | ||||
| // Filter out tcp/ip information | |||||
| ss_str = std::regex_replace(ss_str, std::regex("Hostname.*\n"), ""); | |||||
| ss_str = std::regex_replace(ss_str, std::regex("Port.*\n"), ""); | |||||
| ss_str = std::regex_replace(ss_str, std::regex("Number of rpc workers.*\n"), ""); | |||||
| ss_str = std::regex_replace(ss_str, std::regex("Prefetch size.*\n"), ""); | |||||
| ss_str = std::regex_replace(ss_str, std::regex("Local client support.*\n"), ""); | |||||
| // Filter out Number of rows when generating the check sum | // Filter out Number of rows when generating the check sum | ||||
| ss_str = std::regex_replace(ss_str, std::regex("Number of rows.*\n"), ""); | ss_str = std::regex_replace(ss_str, std::regex("Number of rows.*\n"), ""); | ||||
| @@ -73,6 +73,7 @@ enum class StatusCode : char { | |||||
| kProfilingError = 10, | kProfilingError = 10, | ||||
| kBoundingBoxOutOfBounds = 11, | kBoundingBoxOutOfBounds = 11, | ||||
| kBoundingBoxInvalidShape = 12, | kBoundingBoxInvalidShape = 12, | ||||
| kSyntaxError = 13, | |||||
| // Make this error code the last one. Add new error code above it. | // Make this error code the last one. Add new error code above it. | ||||
| kUnexpectedError = 127 | kUnexpectedError = 127 | ||||
| }; | }; | ||||
| @@ -168,9 +168,9 @@ class MemGuard { | |||||
| size_t GetSizeInBytes() const { return n_ * sizeof(T); } | size_t GetSizeInBytes() const { return n_ * sizeof(T); } | ||||
| private: | private: | ||||
| size_t n_; | |||||
| allocator alloc_; | allocator alloc_; | ||||
| std::unique_ptr<T[]> ptr_; | std::unique_ptr<T[]> ptr_; | ||||
| size_t n_; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -27,20 +27,20 @@ | |||||
| #define ARENA_WALL_OVERHEAD_SZ 32 | #define ARENA_WALL_OVERHEAD_SZ 32 | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // This is a memory arena based on a treap data structure. | |||||
| // The constructor of the Arena takes the size of the initial memory size (in MB). | |||||
| // Internally we divide the memory into multiple blocks. Each block is 64 bytes. | |||||
| // The treap contains all the free blocks with the relative memory address as key | |||||
| // and the size of the block as priority. | |||||
| // | |||||
| // Initially the treap has only one root which is the whole memory piece. | |||||
| // | |||||
| // For memory suballocation, we pop the root node of the treap which contains the largest free block. | |||||
| // We allocate what we need and return the rest back to the treap. We search for the first fit instead | |||||
| // of the best fit so to give us a constant time in memory allocation. | |||||
| // | |||||
| // When a block of memory is freed. It is joined with the blocks before and after (if they are available) to | |||||
| // form a bigger block. | |||||
| /// This is a memory arena based on a treap data structure. | |||||
| /// The constructor of the Arena takes the size of the initial memory size (in MB). | |||||
| /// Internally we divide the memory into multiple blocks. Each block is 64 bytes. | |||||
| /// The treap contains all the free blocks with the relative memory address as key | |||||
| /// and the size of the block as priority. | |||||
| /// | |||||
| /// Initially the treap has only one root which is the whole memory piece. | |||||
| /// | |||||
| /// For memory suballocation, we pop the root node of the treap which contains the largest free block. | |||||
| /// We allocate what we need and return the rest back to the treap. We search for the first fit instead | |||||
| /// of the best fit so to give us a constant time in memory allocation. | |||||
| /// | |||||
| /// When a block of memory is freed. It is joined with the blocks before and after (if they are available) to | |||||
| /// form a bigger block. | |||||
| class Arena : public MemoryPool { | class Arena : public MemoryPool { | ||||
| public: | public: | ||||
| Arena(const Arena &) = delete; | Arena(const Arena &) = delete; | ||||
| @@ -78,7 +78,7 @@ class Arena : public MemoryPool { | |||||
| static Status CreateArena(std::shared_ptr<Arena> *p_ba, size_t val_in_MB = 4096); | static Status CreateArena(std::shared_ptr<Arena> *p_ba, size_t val_in_MB = 4096); | ||||
| private: | |||||
| protected: | |||||
| std::mutex mux_; | std::mutex mux_; | ||||
| Treap<uint64_t, uint64_t> tr_; | Treap<uint64_t, uint64_t> tr_; | ||||
| void *ptr_; | void *ptr_; | ||||
| @@ -140,13 +140,22 @@ Path CachePool::GetSpillPath() const { | |||||
| } | } | ||||
| CachePool::CacheStat CachePool::GetStat() const { | CachePool::CacheStat CachePool::GetStat() const { | ||||
| CacheStat cs{0}; | CacheStat cs{0}; | ||||
| int64_t total_sz = 0; | |||||
| for (auto &it : *tree_) { | for (auto &it : *tree_) { | ||||
| total_sz += it.sz; | |||||
| if (it.ptr != nullptr) { | if (it.ptr != nullptr) { | ||||
| ++cs.num_mem_cached; | ++cs.num_mem_cached; | ||||
| } else { | } else { | ||||
| ++cs.num_disk_cached; | ++cs.num_disk_cached; | ||||
| } | } | ||||
| } | } | ||||
| if (total_sz > 0) { | |||||
| // integer arithmetic. NO need to cast to float or double. | |||||
| cs.average_cache_sz = total_sz / (cs.num_disk_cached + cs.num_mem_cached); | |||||
| if (cs.average_cache_sz == 0) { | |||||
| cs.average_cache_sz = 1; | |||||
| } | |||||
| } | |||||
| return cs; | return cs; | ||||
| } | } | ||||
| Status CachePool::Spill(CachePool::DataLocator *dl) { | Status CachePool::Spill(CachePool::DataLocator *dl) { | ||||
| @@ -82,6 +82,7 @@ class CachePool : public Service { | |||||
| struct CacheStat { | struct CacheStat { | ||||
| int64_t num_mem_cached; | int64_t num_mem_cached; | ||||
| int64_t num_disk_cached; | int64_t num_disk_cached; | ||||
| int64_t average_cache_sz; | |||||
| }; | }; | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| @@ -0,0 +1,127 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_ | |||||
| #include <deque> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <mutex> | |||||
| #include "minddata/dataset/util/allocator.h" | |||||
| #include "minddata/dataset/util/semaphore.h" | |||||
| #include "minddata/dataset/util/services.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| template <typename K, typename T> | |||||
| /// \brief QueueMap is like a Queue but instead of there is a map of deque<T>. | |||||
| /// Consumer will block if the corresponding deque is empty. | |||||
| /// Producer can add an element of type T with key of type K to the map and | |||||
| /// wake up any waiting consumer. | |||||
| /// \tparam K key type | |||||
| /// \tparam T payload of the map | |||||
| class QueueMap { | |||||
| public: | |||||
| using key_type = K; | |||||
| using value_type = T; | |||||
| QueueMap() = default; | |||||
| virtual ~QueueMap() = default; | |||||
| /// Add an element <key, T> to the map and wake up any consumer that is waiting | |||||
| /// \param key | |||||
| /// \param payload | |||||
| /// \return Status object | |||||
| virtual Status Add(key_type key, T &&payload) { | |||||
| RequestQueue *rq = nullptr; | |||||
| RETURN_IF_NOT_OK(GetRq(key, &rq)); | |||||
| RETURN_IF_NOT_OK(rq->WakeUpAny(std::move(payload))); | |||||
| return Status::OK(); | |||||
| } | |||||
| /// Pop the front of the deque with key. Block if the deque is empty. | |||||
| virtual Status PopFront(key_type key, T *out) { | |||||
| RequestQueue *rq = nullptr; | |||||
| RETURN_IF_NOT_OK(GetRq(key, &rq)); | |||||
| RETURN_IF_NOT_OK(rq->Wait(out)); | |||||
| return Status::OK(); | |||||
| } | |||||
| protected: | |||||
| /// This is a handshake structure between producer and consumer | |||||
| class RequestQueue { | |||||
| public: | |||||
| RequestQueue() : use_count_(0) {} | |||||
| ~RequestQueue() = default; | |||||
| Status Wait(T *out) { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| // Block until the missing row is in the pool. | |||||
| RETURN_IF_NOT_OK(use_count_.P()); | |||||
| std::unique_lock<std::mutex> lck(dq_mux_); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error"); | |||||
| *out = std::move(row_.front()); | |||||
| row_.pop_front(); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status WakeUpAny(T &&row) { | |||||
| std::unique_lock<std::mutex> lck(dq_mux_); | |||||
| row_.push_back(std::move(row)); | |||||
| // Bump up the use count by 1. This wake up any parallel worker which is waiting | |||||
| // for this row. | |||||
| use_count_.V(); | |||||
| return Status::OK(); | |||||
| } | |||||
| private: | |||||
| std::mutex dq_mux_; | |||||
| Semaphore use_count_; | |||||
| std::deque<T> row_; | |||||
| }; | |||||
| /// Create or locate an element with matching key | |||||
| /// \param key | |||||
| /// \param out | |||||
| /// \return Status object | |||||
| Status GetRq(key_type key, RequestQueue **out) { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| std::unique_lock<std::mutex> lck(mux_); | |||||
| auto it = all_.find(key); | |||||
| if (it != all_.end()) { | |||||
| *out = it->second.GetMutablePointer(); | |||||
| } else { | |||||
| // We will create a new one. | |||||
| auto alloc = Services::GetAllocator<RequestQueue>(); | |||||
| auto r = all_.emplace(key, MemGuard<RequestQueue, Allocator<RequestQueue>>(alloc)); | |||||
| if (r.second) { | |||||
| auto &mem = r.first->second; | |||||
| RETURN_IF_NOT_OK(mem.allocate(1)); | |||||
| *out = mem.GetMutablePointer(); | |||||
| } else { | |||||
| RETURN_STATUS_UNEXPECTED("Map insert fail."); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| private: | |||||
| std::mutex mux_; | |||||
| std::map<K, MemGuard<RequestQueue, Allocator<RequestQueue>>> all_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_ | |||||
| @@ -22,7 +22,6 @@ | |||||
| #include <stdlib.h> | #include <stdlib.h> | ||||
| #endif | #endif | ||||
| #include <unistd.h> | #include <unistd.h> | ||||
| #include "minddata/dataset/engine/cache/cache_server.h" | |||||
| #include "minddata/dataset/util/circular_pool.h" | #include "minddata/dataset/util/circular_pool.h" | ||||
| #include "minddata/dataset/util/random.h" | #include "minddata/dataset/util/random.h" | ||||
| #include "minddata/dataset/util/task_manager.h" | #include "minddata/dataset/util/task_manager.h" | ||||
| @@ -59,35 +58,15 @@ std::string Services::GetUniqueID() { | |||||
| return std::string(buffer, UNIQUEID_LEN); | return std::string(buffer, UNIQUEID_LEN); | ||||
| } | } | ||||
| TaskManager &Services::getTaskMgrInstance() { | |||||
| Services &sm = GetInstance(); | |||||
| return *(static_cast<TaskManager *>(sm.sa_[kSlotTaskMgr_])); | |||||
| } | |||||
| CacheServer &Services::getCacheServer() { | |||||
| Services &sm = GetInstance(); | |||||
| return *(static_cast<CacheServer *>(sm.sa_[kSlotCacheMgr_])); | |||||
| } | |||||
| Status Services::CreateAllInstances() { | Status Services::CreateAllInstances() { | ||||
| // In order, TaskMgr, BufferMgr | |||||
| Status rc; | |||||
| sa_[kSlotTaskMgr_] = new (&rc, pool_) TaskManager(); | |||||
| RETURN_IF_NOT_OK(rc); | |||||
| rc = sa_[kSlotTaskMgr_]->ServiceStart(); | |||||
| RETURN_IF_NOT_OK(rc); | |||||
| // TODO(jesse) : Get the parameters from config file. Right now spill to /tmp and spawn 3 workers | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| sa_[kSlotCacheMgr_] = new (&rc, pool_) CacheServer("/tmp", 3); | |||||
| RETURN_IF_NOT_OK(rc); | |||||
| rc = sa_[kSlotCacheMgr_]->ServiceStart(); | |||||
| #else | |||||
| sa_[kSlotCacheMgr_] = nullptr; | |||||
| #endif | |||||
| return rc; | |||||
| // First one is always the TaskManager | |||||
| RETURN_IF_NOT_OK(TaskManager::CreateInstance()); | |||||
| TaskManager &tm = TaskManager::GetInstance(); | |||||
| RETURN_IF_NOT_OK(tm.ServiceStart()); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| Services::Services() : pool_(nullptr), sa_{nullptr} { | |||||
| Services::Services() : pool_(nullptr) { | |||||
| Status rc = CircularPool::CreateCircularPool(&pool_, -1, 16, true); // each arena 16M | Status rc = CircularPool::CreateCircularPool(&pool_, -1, 16, true); // each arena 16M | ||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| std::terminate(); | std::terminate(); | ||||
| @@ -95,22 +74,11 @@ Services::Services() : pool_(nullptr), sa_{nullptr} { | |||||
| } | } | ||||
| Services::~Services() noexcept { | Services::~Services() noexcept { | ||||
| try { | |||||
| // In reverse order | |||||
| CacheServer *cs = static_cast<CacheServer *>(sa_[kSlotCacheMgr_]); | |||||
| if (cs != nullptr) { | |||||
| (void)cs->ServiceStop(); | |||||
| cs->~CacheServer(); | |||||
| pool_->Deallocate(cs); | |||||
| } | |||||
| TaskManager *tm = static_cast<TaskManager *>(sa_[kSlotTaskMgr_]); | |||||
| if (tm != nullptr) { | |||||
| (void)tm->ServiceStop(); | |||||
| tm->~TaskManager(); | |||||
| pool_->Deallocate(tm); | |||||
| } | |||||
| } catch (const std::exception &e) { | |||||
| // Do nothing. | |||||
| // Shutdown in reverse order. | |||||
| auto n = hook_.size(); | |||||
| while (n > 0) { | |||||
| hook_.pop_back(); | |||||
| n = hook_.size(); | |||||
| } | } | ||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -16,9 +16,11 @@ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SERVICES_H_ | #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SERVICES_H_ | ||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SERVICES_H_ | #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SERVICES_H_ | ||||
| #include <algorithm> | |||||
| #include <memory> | #include <memory> | ||||
| #include <mutex> | #include <mutex> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | |||||
| #include "minddata/dataset/util/memory_pool.h" | #include "minddata/dataset/util/memory_pool.h" | ||||
| #include "minddata/dataset/util/allocator.h" | #include "minddata/dataset/util/allocator.h" | ||||
| #include "minddata/dataset/util/service.h" | #include "minddata/dataset/util/service.h" | ||||
| @@ -27,7 +29,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class TaskManager; | class TaskManager; | ||||
| class CacheServer; | |||||
| class Services { | class Services { | ||||
| public: | public: | ||||
| static Status CreateInstance() { | static Status CreateInstance() { | ||||
| @@ -59,10 +61,6 @@ class Services { | |||||
| ~Services() noexcept; | ~Services() noexcept; | ||||
| static TaskManager &getTaskMgrInstance(); | |||||
| static CacheServer &getCacheServer(); | |||||
| std::shared_ptr<MemoryPool> GetServiceMemPool() { return pool_; } | std::shared_ptr<MemoryPool> GetServiceMemPool() { return pool_; } | ||||
| #if !defined(_WIN32) && !defined(_WIN64) | #if !defined(_WIN32) && !defined(_WIN64) | ||||
| @@ -80,19 +78,29 @@ class Services { | |||||
| return Allocator<T>(Services::GetInstance().GetServiceMemPool()); | return Allocator<T>(Services::GetInstance().GetServiceMemPool()); | ||||
| } | } | ||||
| /// \brief Add a new service to the start up list. | |||||
| /// \tparam T Class that implements Service | |||||
| /// \return Status object and where the service is located in the hook_ list | |||||
| template <typename T, typename... Args> | |||||
| Status AddHook(T **out, Args &&... args) { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| try { | |||||
| (*out) = new T(std::forward<Args>(args)...); | |||||
| std::unique_ptr<T> svc(*out); | |||||
| hook_.push_back(std::move(svc)); | |||||
| } catch (const std::bad_alloc &e) { | |||||
| return Status(StatusCode::kOutOfMemory); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| private: | private: | ||||
| static std::once_flag init_instance_flag_; | static std::once_flag init_instance_flag_; | ||||
| static std::unique_ptr<Services> instance_; | static std::unique_ptr<Services> instance_; | ||||
| // A small pool used for small objects that last until the | // A small pool used for small objects that last until the | ||||
| // Services Manager shuts down. Used by all sub-services. | // Services Manager shuts down. Used by all sub-services. | ||||
| std::shared_ptr<MemoryPool> pool_; | std::shared_ptr<MemoryPool> pool_; | ||||
| // We use pointers here instead of unique_ptr because we | |||||
| // want to have ultimate control on the order of | |||||
| // construction and destruction. | |||||
| static constexpr int kSlotTaskMgr_ = 0; | |||||
| static constexpr int kSlotCacheMgr_ = 1; | |||||
| static constexpr int kNumServices_ = 2; | |||||
| Service *sa_[kNumServices_]; | |||||
| std::vector<std::unique_ptr<Service>> hook_; | |||||
| Services(); | Services(); | ||||
| @@ -86,6 +86,7 @@ class ReadableSlice { | |||||
| class WritableSlice : public ReadableSlice { | class WritableSlice : public ReadableSlice { | ||||
| public: | public: | ||||
| friend class StorageContainer; | friend class StorageContainer; | ||||
| friend class CacheService; | |||||
| /// \brief Default constructor | /// \brief Default constructor | ||||
| WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {} | WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {} | ||||
| /// \brief This form of a constructor takes a pointer and its size. | /// \brief This form of a constructor takes a pointer and its size. | ||||
| @@ -48,6 +48,9 @@ std::string CodeAsString(const StatusCode c) { | |||||
| case StatusCode::kProfilingError: | case StatusCode::kProfilingError: | ||||
| s = "Error encountered while profiling"; | s = "Error encountered while profiling"; | ||||
| break; | break; | ||||
| case StatusCode::kSyntaxError: | |||||
| s = "Syntax error"; | |||||
| break; | |||||
| case StatusCode::kUnexpectedError: | case StatusCode::kUnexpectedError: | ||||
| default: | default: | ||||
| s = "Unexpected error"; | s = "Unexpected error"; | ||||
| @@ -80,6 +80,7 @@ enum class StatusCode : char { | |||||
| kProfilingError = 10, | kProfilingError = 10, | ||||
| kBoundingBoxOutOfBounds = 11, | kBoundingBoxOutOfBounds = 11, | ||||
| kBoundingBoxInvalidShape = 12, | kBoundingBoxInvalidShape = 12, | ||||
| kSyntaxError = 13, | |||||
| // Make this error code the last one. Add new error code above it. | // Make this error code the last one. Add new error code above it. | ||||
| kUnexpectedError = 127 | kUnexpectedError = 127 | ||||
| }; | }; | ||||
| @@ -21,6 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| TaskManager *TaskManager::instance_ = nullptr; | |||||
| std::once_flag TaskManager::init_instance_flag_; | |||||
| // This takes the same parameter as Task constructor. | // This takes the same parameter as Task constructor. | ||||
| Status TaskManager::CreateAsyncTask(const std::string &my_name, const std::function<Status()> &f, TaskGroup *vg, | Status TaskManager::CreateAsyncTask(const std::string &my_name, const std::function<Status()> &f, TaskGroup *vg, | ||||
| Task **task) { | Task **task) { | ||||
| @@ -54,7 +54,16 @@ class TaskManager : public Service { | |||||
| TaskManager &operator=(const TaskManager &) = delete; | TaskManager &operator=(const TaskManager &) = delete; | ||||
| static TaskManager &GetInstance() noexcept { return Services::getTaskMgrInstance(); } | |||||
| static Status CreateInstance() { | |||||
| std::call_once(init_instance_flag_, [&]() -> Status { | |||||
| auto &svcManager = Services::GetInstance(); | |||||
| RETURN_IF_NOT_OK(svcManager.AddHook(&instance_)); | |||||
| return Status::OK(); | |||||
| }); | |||||
| return Status::OK(); | |||||
| } | |||||
| static TaskManager &GetInstance() noexcept { return *instance_; } | |||||
| Status DoServiceStart() override; | Status DoServiceStart() override; | ||||
| @@ -96,6 +105,8 @@ class TaskManager : public Service { | |||||
| Status WatchDog(); | Status WatchDog(); | ||||
| private: | private: | ||||
| static std::once_flag init_instance_flag_; | |||||
| static TaskManager *instance_; | |||||
| RWLock lru_lock_; | RWLock lru_lock_; | ||||
| SpinLock free_lock_; | SpinLock free_lock_; | ||||
| SpinLock tg_lock_; | SpinLock tg_lock_; | ||||
| @@ -25,15 +25,22 @@ class DatasetCache: | |||||
| A client to interface with tensor caching service | A client to interface with tensor caching service | ||||
| """ | """ | ||||
| def __init__(self, session_id=None, size=0, spilling=False): | |||||
| def __init__(self, session_id=None, size=0, spilling=False, port=50052, prefetch_size=20): | |||||
| check_uint32(session_id, "session_id") | check_uint32(session_id, "session_id") | ||||
| check_uint64(size, "size") | check_uint64(size, "size") | ||||
| type_check(spilling, (bool,), "spilling") | type_check(spilling, (bool,), "spilling") | ||||
| check_uint32(port, "port") | |||||
| check_uint32(prefetch_size, "prefetch size") | |||||
| self.session_id = session_id | self.session_id = session_id | ||||
| self.size = size | self.size = size | ||||
| self.spilling = spilling | self.spilling = spilling | ||||
| self.cache_client = CacheClient(session_id, size, spilling) | |||||
| self.port = port | |||||
| self.prefetch_size = prefetch_size | |||||
| self.cache_client = CacheClient(session_id, size, spilling, port, prefetch_size) | |||||
| def GetStat(self): | |||||
| return self.cache_client.GetStat() | |||||
| def __deepcopy__(self, memodict): | def __deepcopy__(self, memodict): | ||||
| if id(self) in memodict: | if id(self) in memodict: | ||||
| @@ -44,5 +51,7 @@ class DatasetCache: | |||||
| new_cache.session_id = copy.deepcopy(self.session_id, memodict) | new_cache.session_id = copy.deepcopy(self.session_id, memodict) | ||||
| new_cache.spilling = copy.deepcopy(self.spilling, memodict) | new_cache.spilling = copy.deepcopy(self.spilling, memodict) | ||||
| new_cache.size = copy.deepcopy(self.size, memodict) | new_cache.size = copy.deepcopy(self.size, memodict) | ||||
| new_cache.port = copy.deepcopy(self.port, memodict) | |||||
| new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict) | |||||
| new_cache.cache_client = self.cache_client | new_cache.cache_client = self.cache_client | ||||
| return new_cache | return new_cache | ||||
| @@ -43,13 +43,18 @@ class MindDataTestCacheOp : public UT::DatasetOpTesting { | |||||
| } | } | ||||
| }; | }; | ||||
| TEST_F(MindDataTestCacheOp, TestCacheServer) { | |||||
| TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) { | |||||
| Status rc; | Status rc; | ||||
| CacheClient myClient(1, 0, true); // use arbitrary session of 1, size of 0, spilling is true | |||||
| CacheClient::Builder builder; | |||||
| // use arbitrary session of 1, size of 0, spilling// is true | |||||
| builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); | |||||
| std::shared_ptr<CacheClient> myClient; | |||||
| rc = builder.Build(&myClient); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated. | // cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated. | ||||
| rc = myClient.CreateCache(1, true); | |||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| std::cout << myClient << std::endl; | |||||
| rc = myClient->CreateCache(1, true); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| std::cout << *myClient << std::endl; | |||||
| // Create a schema using the C api's | // Create a schema using the C api's | ||||
| int32_t rank = 0; // not used | int32_t rank = 0; // not used | ||||
| @@ -68,11 +73,11 @@ TEST_F(MindDataTestCacheOp, TestCacheServer) { | |||||
| std::unordered_map<std::string, int32_t> map; | std::unordered_map<std::string, int32_t> map; | ||||
| rc = testSchema->GetColumnNameMap(&map); | rc = testSchema->GetColumnNameMap(&map); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // Test the CacheSchema api | // Test the CacheSchema api | ||||
| rc = myClient.CacheSchema(map); | |||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| rc = myClient->CacheSchema(map); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // Create a tensor, take a snapshot and restore it back, and compare. | // Create a tensor, take a snapshot and restore it back, and compare. | ||||
| std::shared_ptr<Tensor> t; | std::shared_ptr<Tensor> t; | ||||
| @@ -88,48 +93,54 @@ TEST_F(MindDataTestCacheOp, TestCacheServer) { | |||||
| TensorRow row; | TensorRow row; | ||||
| row.push_back(t); | row.push_back(t); | ||||
| int64_t row_id; | int64_t row_id; | ||||
| rc = myClient.WriteRow(row, &row_id); | |||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| rc = myClient->WriteRow(row, &row_id); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // Switch off build phase. | // Switch off build phase. | ||||
| rc = myClient.BuildPhaseDone(); | |||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| rc = myClient->BuildPhaseDone(); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // Now restore from cache. | // Now restore from cache. | ||||
| row.clear(); | row.clear(); | ||||
| rc = myClient.GetRows({row_id}, &tbl); | |||||
| rc = myClient->GetRows({row_id}, &tbl); | |||||
| row = tbl.front(); | row = tbl.front(); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| auto r = row.front(); | auto r = row.front(); | ||||
| std::cout << *r << std::endl; | std::cout << *r << std::endl; | ||||
| // Compare | // Compare | ||||
| bool cmp = (*t == *r); | bool cmp = (*t == *r); | ||||
| EXPECT_TRUE(cmp); | |||||
| ASSERT_TRUE(cmp); | |||||
| // Get back the schema and verify | // Get back the schema and verify | ||||
| std::unordered_map<std::string, int32_t> map_out; | std::unordered_map<std::string, int32_t> map_out; | ||||
| rc = myClient.FetchSchema(&map_out); | |||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| rc = myClient->FetchSchema(&map_out); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| cmp = (map_out == map); | cmp = (map_out == map); | ||||
| EXPECT_TRUE(cmp); | |||||
| ASSERT_TRUE(cmp); | |||||
| // Test Purge and Destroy | // Test Purge and Destroy | ||||
| rc = myClient.PurgeCache(); | |||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| rc = myClient.DestroyCache(); | |||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| rc = myClient->PurgeCache(); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myClient->DestroyCache(); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| } | } | ||||
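| // Every test in this file now builds its client the same way: configure a | |||||
| // CacheClient::Builder, then Build() a shared_ptr<CacheClient>. A minimal | |||||
| // sketch of that flow (a sketch only; parameter types are assumed from the | |||||
| // literals used in these tests, and the helper is not called by the suite): | |||||
| static Status MakeTestClient(uint32_t session_id, uint64_t mem_sz, bool spill, | |||||
| std::shared_ptr<CacheClient> *out) { | |||||
| CacheClient::Builder b; | |||||
| b.SetSessionId(session_id).SetCacheMemSz(mem_sz).SetSpill(spill); | |||||
| return b.Build(out); // Build() validates the settings before creating the client | |||||
| } | |||||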
| TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) { | |||||
| TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) { | |||||
| // Clear the rc of the master thread if any | // Clear the rc of the master thread if any | ||||
| (void)TaskManager::GetMasterThreadRc(); | (void)TaskManager::GetMasterThreadRc(); | ||||
| TaskGroup vg; | TaskGroup vg; | ||||
| Status rc; | Status rc; | ||||
| CacheClient myClient(1, 1, true); // use arbitrary session of 1, size 1, spilling is true | |||||
| // use arbitrary session of 1, size of 1, spilling is true | |||||
| CacheClient::Builder builder; | |||||
| builder.SetSessionId(1).SetCacheMemSz(1).SetSpill(true); | |||||
| std::shared_ptr<CacheClient> myClient; | |||||
| rc = builder.Build(&myClient); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // cksum value of 1 for CreateCache here... normally you do not directly create a cache; the cksum arg is generated. | // cksum value of 1 for CreateCache here... normally you do not directly create a cache; the cksum arg is generated. | ||||
| rc = myClient.CreateCache(1, true); | |||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| std::cout << myClient << std::endl; | |||||
| rc = myClient->CreateCache(1, true); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| std::cout << *myClient << std::endl; | |||||
| std::shared_ptr<Tensor> t; | std::shared_ptr<Tensor> t; | ||||
| Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t); | Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t); | ||||
| t->SetItemAt<uint64_t>({0, 0}, 1); | t->SetItemAt<uint64_t>({0, 0}, 1); | ||||
| @@ -146,19 +157,19 @@ TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) { | |||||
| Status vg_rc = vg.CreateAsyncTask("Test agent", [&myClient, &row]() -> Status { | Status vg_rc = vg.CreateAsyncTask("Test agent", [&myClient, &row]() -> Status { | ||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| for (auto i = 0; i < 500; i++) { | for (auto i = 0; i < 500; i++) { | ||||
| RETURN_IF_NOT_OK(myClient.WriteRow(row)); | |||||
| RETURN_IF_NOT_OK(myClient->WriteRow(row)); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| }); | }); | ||||
| EXPECT_TRUE(vg_rc.IsOk()); | |||||
| ASSERT_TRUE(vg_rc.IsOk()); | |||||
| } | } | ||||
| ASSERT_TRUE(vg.join_all().IsOk()); | ASSERT_TRUE(vg.join_all().IsOk()); | ||||
| ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOk()); | ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOk()); | ||||
| rc = myClient.BuildPhaseDone(); | |||||
| rc = myClient->BuildPhaseDone(); | |||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| // Get statistics from the server. | // Get statistics from the server. | ||||
| CacheClient::ServiceStat stat{}; | |||||
| rc = myClient.GetStat(&stat); | |||||
| CacheServiceStat stat{}; | |||||
| rc = myClient->GetStat(&stat); | |||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| std::cout << stat.min_row_id << ":" << stat.max_row_id << ":" << stat.num_mem_cached << ":" << stat.num_disk_cached | std::cout << stat.min_row_id << ":" << stat.max_row_id << ":" << stat.num_mem_cached << ":" << stat.num_disk_cached | ||||
| << "\n"; | << "\n"; | ||||
| @@ -168,15 +179,15 @@ TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) { | |||||
| for (auto i = stat.min_row_id; i <= stat.max_row_id; ++i) { | for (auto i = stat.min_row_id; i <= stat.max_row_id; ++i) { | ||||
| tbl.clear(); | tbl.clear(); | ||||
| row.clear(); | row.clear(); | ||||
| rc = myClient.GetRows({i}, &tbl); | |||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| rc = myClient->GetRows({i}, &tbl); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| row = tbl.front(); | row = tbl.front(); | ||||
| auto r = row.front(); | auto r = row.front(); | ||||
| bool cmp = (*t == *r); | bool cmp = (*t == *r); | ||||
| EXPECT_TRUE(cmp); | |||||
| ASSERT_TRUE(cmp); | |||||
| } | } | ||||
| rc = myClient.DestroyCache(); | |||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| rc = myClient->DestroyCache(); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| } | } | ||||
| // Simple test with a repeated cache op over random data producer | // Simple test with a repeated cache op over random data producer | ||||
| @@ -187,7 +198,7 @@ TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) { | |||||
| // | | // | | ||||
| // RandomDataOp | // RandomDataOp | ||||
| // | // | ||||
| TEST_F(MindDataTestCacheOp, TestRandomDataCache1) { | |||||
| TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { | |||||
| Status rc; | Status rc; | ||||
| int32_t rank = 0; // not used | int32_t rank = 0; // not used | ||||
| MS_LOG(INFO) << "UT test TestRandomDataCache1"; | MS_LOG(INFO) << "UT test TestRandomDataCache1"; | ||||
| @@ -218,13 +229,18 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCache1) { | |||||
| .SetDataSchema(std::move(testSchema)) | .SetDataSchema(std::move(testSchema)) | ||||
| .SetTotalRows(50) // 50 samples for now | .SetTotalRows(50) // 50 samples for now | ||||
| .Build(&myRandomDataOp); | .Build(&myRandomDataOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myTree->AssociateNode(myRandomDataOp); | rc = myTree->AssociateNode(myRandomDataOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // CacheOp | // CacheOp | ||||
| // size of 0, spilling is true | // size of 0, spilling is true | ||||
| std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true); | |||||
| CacheClient::Builder builder; | |||||
| // use arbitrary session of 1, size of 0, spilling is true | |||||
| builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); | |||||
| std::shared_ptr<CacheClient> myClient; | |||||
| rc = builder.Build(&myClient); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| std::shared_ptr<CacheOp> myCacheOp; | std::shared_ptr<CacheOp> myCacheOp; | ||||
| int64_t num_samples = 0; | int64_t num_samples = 0; | ||||
| @@ -236,29 +252,29 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCache1) { | |||||
| .SetRowsPerBuffer(4) | .SetRowsPerBuffer(4) | ||||
| .SetSampler(std::move(seq_sampler)) | .SetSampler(std::move(seq_sampler)) | ||||
| .Build(&myCacheOp); | .Build(&myCacheOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myTree->AssociateNode(myCacheOp); | rc = myTree->AssociateNode(myCacheOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // RepeatOp | // RepeatOp | ||||
| uint32_t numRepeats = 4; | uint32_t numRepeats = 4; | ||||
| std::shared_ptr<RepeatOp> myRepeatOp; | std::shared_ptr<RepeatOp> myRepeatOp; | ||||
| rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); | rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myTree->AssociateNode(myRepeatOp); | rc = myTree->AssociateNode(myRepeatOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // Assign tree relations and root | // Assign tree relations and root | ||||
| rc = myRepeatOp->AddChild(myCacheOp); | rc = myRepeatOp->AddChild(myCacheOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myCacheOp->AddChild(myRandomDataOp); | rc = myCacheOp->AddChild(myRandomDataOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myTree->AssignRoot(myRepeatOp); | rc = myTree->AssignRoot(myRepeatOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | MS_LOG(INFO) << "Launching tree and begin iteration"; | ||||
| rc = myTree->Prepare(); | rc = myTree->Prepare(); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // quick check to see what tree looks like | // quick check to see what tree looks like | ||||
| std::ostringstream ss; | std::ostringstream ss; | ||||
| @@ -268,24 +284,24 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCache1) { | |||||
| std::cout << *myClient << std::endl; | std::cout << *myClient << std::endl; | ||||
| rc = myTree->Launch(); | rc = myTree->Launch(); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // Start the loop of reading tensors from our pipeline | // Start the loop of reading tensors from our pipeline | ||||
| DatasetIterator dI(myTree); | DatasetIterator dI(myTree); | ||||
| TensorRow tensorList; | TensorRow tensorList; | ||||
| rc = dI.FetchNextTensorRow(&tensorList); | rc = dI.FetchNextTensorRow(&tensorList); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| int rowCount = 0; | int rowCount = 0; | ||||
| while (!tensorList.empty()) { | while (!tensorList.empty()) { | ||||
| // Don't display these rows, just count them | // Don't display these rows, just count them | ||||
| MS_LOG(INFO) << "Row fetched #: " << rowCount; | MS_LOG(INFO) << "Row fetched #: " << rowCount; | ||||
| rc = dI.FetchNextTensorRow(&tensorList); | rc = dI.FetchNextTensorRow(&tensorList); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rowCount++; | rowCount++; | ||||
| } | } | ||||
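| // 50 RandomData rows x 4 repeats = 200 | |||||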
| ASSERT_EQ(rowCount, 200); | ASSERT_EQ(rowCount, 200); | ||||
| rc = myClient->DestroyCache(); | rc = myClient->DestroyCache(); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| } | } | ||||
| //// Simple test with a repeated cache op over random data producer. | //// Simple test with a repeated cache op over random data producer. | ||||
| @@ -297,7 +313,7 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCache1) { | |||||
| //// | | //// | | ||||
| //// RandomDataOp | //// RandomDataOp | ||||
| //// | //// | ||||
| TEST_F(MindDataTestCacheOp, TestRandomDataCacheSpill) { | |||||
| TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { | |||||
| Status rc; | Status rc; | ||||
| int32_t rank = 0; // not used | int32_t rank = 0; // not used | ||||
| MS_LOG(INFO) << "UT test TestRandomDataCacheSpill"; | MS_LOG(INFO) << "UT test TestRandomDataCacheSpill"; | ||||
| @@ -328,15 +344,20 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCacheSpill) { | |||||
| .SetDataSchema(std::move(testSchema)) | .SetDataSchema(std::move(testSchema)) | ||||
| .SetTotalRows(10) | .SetTotalRows(10) | ||||
| .Build(&myRandomDataOp); | .Build(&myRandomDataOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myTree->AssociateNode(myRandomDataOp); | rc = myTree->AssociateNode(myRandomDataOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // CacheOp | // CacheOp | ||||
| int64_t num_samples = 0; | int64_t num_samples = 0; | ||||
| int64_t start_index = 0; | int64_t start_index = 0; | ||||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | ||||
| std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 4, true); | |||||
| CacheClient::Builder builder; | |||||
| // use arbitrary session of 1, size of 4, spilling is true | |||||
| builder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true); | |||||
| std::shared_ptr<CacheClient> myClient; | |||||
| rc = builder.Build(&myClient); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| std::shared_ptr<CacheOp> myCacheOp; | std::shared_ptr<CacheOp> myCacheOp; | ||||
| rc = CacheOp::Builder() | rc = CacheOp::Builder() | ||||
| .SetNumWorkers(4) | .SetNumWorkers(4) | ||||
| @@ -344,60 +365,65 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCacheSpill) { | |||||
| .SetRowsPerBuffer(3) | .SetRowsPerBuffer(3) | ||||
| .SetSampler(std::move(seq_sampler)) | .SetSampler(std::move(seq_sampler)) | ||||
| .Build(&myCacheOp); | .Build(&myCacheOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myTree->AssociateNode(myCacheOp); | rc = myTree->AssociateNode(myCacheOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // RepeatOp | // RepeatOp | ||||
| uint32_t numRepeats = 4; | uint32_t numRepeats = 4; | ||||
| std::shared_ptr<RepeatOp> myRepeatOp; | std::shared_ptr<RepeatOp> myRepeatOp; | ||||
| rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); | rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myTree->AssociateNode(myRepeatOp); | rc = myTree->AssociateNode(myRepeatOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // Assign tree relations and root | // Assign tree relations and root | ||||
| rc = myRepeatOp->AddChild(myCacheOp); | rc = myRepeatOp->AddChild(myCacheOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myCacheOp->AddChild(myRandomDataOp); | rc = myCacheOp->AddChild(myRandomDataOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myTree->AssignRoot(myRepeatOp); | rc = myTree->AssignRoot(myRepeatOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | MS_LOG(INFO) << "Launching tree and begin iteration"; | ||||
| rc = myTree->Prepare(); | rc = myTree->Prepare(); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| std::cout << *myClient << std::endl; | std::cout << *myClient << std::endl; | ||||
| rc = myTree->Launch(); | rc = myTree->Launch(); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // Start the loop of reading tensors from our pipeline | // Start the loop of reading tensors from our pipeline | ||||
| DatasetIterator dI(myTree); | DatasetIterator dI(myTree); | ||||
| TensorRow tensorList; | TensorRow tensorList; | ||||
| rc = dI.FetchNextTensorRow(&tensorList); | rc = dI.FetchNextTensorRow(&tensorList); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| int rowCount = 0; | int rowCount = 0; | ||||
| while (!tensorList.empty()) { | while (!tensorList.empty()) { | ||||
| // Don't display these rows, just count them | // Don't display these rows, just count them | ||||
| MS_LOG(INFO) << "Row fetched #: " << rowCount; | MS_LOG(INFO) << "Row fetched #: " << rowCount; | ||||
| rc = dI.FetchNextTensorRow(&tensorList); | rc = dI.FetchNextTensorRow(&tensorList); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rowCount++; | rowCount++; | ||||
| } | } | ||||
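| // 10 RandomData rows x 4 repeats = 40 | |||||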
| ASSERT_EQ(rowCount, 40); | ASSERT_EQ(rowCount, 40); | ||||
| rc = myClient->DestroyCache(); | rc = myClient->DestroyCache(); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| } | } | ||||
| TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { | |||||
| TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { | |||||
| Status rc; | Status rc; | ||||
| int64_t num_samples = 0; | int64_t num_samples = 0; | ||||
| int64_t start_index = 0; | int64_t start_index = 0; | ||||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | ||||
| std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true); | |||||
| CacheClient::Builder ccbuilder; | |||||
| // use arbitrary session of 1, size of 0, spilling is true | |||||
| ccbuilder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); | |||||
| std::shared_ptr<CacheClient> myClient; | |||||
| rc = ccbuilder.Build(&myClient); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // In a mappable dataset, the cache uses a complex interaction of the cache lookup op and cache merge op. | // In a mappable dataset, the cache uses a complex interaction of the cache lookup op and cache merge op. | ||||
| // Rather than manually build this, the way to do it is to choose the position of the cache in the tree by | // Rather than manually build this, the way to do it is to choose the position of the cache in the tree by | ||||
| @@ -417,44 +443,44 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { | |||||
| .SetRecursive(true) | .SetRecursive(true) | ||||
| .SetImageFolderDir(datasets_root_path_ + "/testPK/data"); | .SetImageFolderDir(datasets_root_path_ + "/testPK/data"); | ||||
| rc = builder.Build(&so); | rc = builder.Build(&so); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // RepeatOp | // RepeatOp | ||||
| uint32_t numRepeats = 4; | uint32_t numRepeats = 4; | ||||
| std::shared_ptr<RepeatOp> myRepeatOp; | std::shared_ptr<RepeatOp> myRepeatOp; | ||||
| rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); | rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| auto myTree = std::make_shared<ExecutionTree>(); | auto myTree = std::make_shared<ExecutionTree>(); | ||||
| rc = myTree->AssociateNode(so); | rc = myTree->AssociateNode(so); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myTree->AssociateNode(myCacheOp); | rc = myTree->AssociateNode(myCacheOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myTree->AssociateNode(myRepeatOp); | rc = myTree->AssociateNode(myRepeatOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myTree->AssignRoot(myRepeatOp); | rc = myTree->AssignRoot(myRepeatOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myRepeatOp->AddChild(myCacheOp); | rc = myRepeatOp->AddChild(myCacheOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myCacheOp->AddChild(so); | rc = myCacheOp->AddChild(so); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myTree->Prepare(); | rc = myTree->Prepare(); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myTree->Launch(); | rc = myTree->Launch(); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // Start the loop of reading tensors from our pipeline | // Start the loop of reading tensors from our pipeline | ||||
| DatasetIterator dI(myTree); | DatasetIterator dI(myTree); | ||||
| TensorRow tensorList; | TensorRow tensorList; | ||||
| rc = dI.FetchNextTensorRow(&tensorList); | rc = dI.FetchNextTensorRow(&tensorList); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| int rowCount = 0; | int rowCount = 0; | ||||
| while (!tensorList.empty()) { | while (!tensorList.empty()) { | ||||
| rc = dI.FetchNextTensorRow(&tensorList); | rc = dI.FetchNextTensorRow(&tensorList); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| std::cout << rc << std::endl; | std::cout << rc << std::endl; | ||||
| break; | break; | ||||
| @@ -464,7 +490,7 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { | |||||
| ASSERT_EQ(rowCount, 176); | ASSERT_EQ(rowCount, 176); | ||||
| std::cout << "Row count : " << rowCount << std::endl; | std::cout << "Row count : " << rowCount << std::endl; | ||||
| rc = myClient->DestroyCache(); | rc = myClient->DestroyCache(); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| } | } | ||||
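| // For the mappable test above, Prepare() is what splits the single cache op | |||||
| // into the lookup/merge pair; the executed tree is roughly as follows (a | |||||
| // sketch inferred from the cache merge comment inside the test, not output | |||||
| // printed by it): | |||||
| // | |||||
| //          RepeatOp | |||||
| //             | | |||||
| //        CacheMergeOp | |||||
| //         /         \ | |||||
| //  CacheLookupOp   ImageFolderOp | |||||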
| //// Simple test with a repeated cache op over random data producer. | //// Simple test with a repeated cache op over random data producer. | ||||
| @@ -480,7 +506,7 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { | |||||
| //// | | //// | | ||||
| //// RandomDataOp | //// RandomDataOp | ||||
| //// | //// | ||||
| TEST_F(MindDataTestCacheOp, TestCacheInheritSampler) { | |||||
| TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) { | |||||
| Status rc; | Status rc; | ||||
| int32_t rank = 0; // not used | int32_t rank = 0; // not used | ||||
| MS_LOG(INFO) << "UT test TestCacheInheritSampler"; | MS_LOG(INFO) << "UT test TestCacheInheritSampler"; | ||||
| @@ -517,57 +543,62 @@ TEST_F(MindDataTestCacheOp, TestCacheInheritSampler) { | |||||
| .SetTotalRows(10) | .SetTotalRows(10) | ||||
| .SetSampler(std::move(seq_sampler)) | .SetSampler(std::move(seq_sampler)) | ||||
| .Build(&myRandomDataOp); | .Build(&myRandomDataOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myTree->AssociateNode(myRandomDataOp); | rc = myTree->AssociateNode(myRandomDataOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // CacheOp | // CacheOp | ||||
| std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 4, true); | |||||
| CacheClient::Builder ccbuilder; | |||||
| // use arbitrary session of 1, size of 4, spilling is true | |||||
| ccbuilder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true); | |||||
| std::shared_ptr<CacheClient> myClient; | |||||
| rc = ccbuilder.Build(&myClient); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| std::shared_ptr<CacheOp> myCacheOp; | std::shared_ptr<CacheOp> myCacheOp; | ||||
| rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp); | rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myTree->AssociateNode(myCacheOp); | rc = myTree->AssociateNode(myCacheOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // RepeatOp | // RepeatOp | ||||
| uint32_t numRepeats = 4; | uint32_t numRepeats = 4; | ||||
| std::shared_ptr<RepeatOp> myRepeatOp; | std::shared_ptr<RepeatOp> myRepeatOp; | ||||
| rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); | rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myTree->AssociateNode(myRepeatOp); | rc = myTree->AssociateNode(myRepeatOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // Assign tree relations and root | // Assign tree relations and root | ||||
| rc = myRepeatOp->AddChild(myCacheOp); | rc = myRepeatOp->AddChild(myCacheOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myCacheOp->AddChild(myRandomDataOp); | rc = myCacheOp->AddChild(myRandomDataOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myTree->AssignRoot(myRepeatOp); | rc = myTree->AssignRoot(myRepeatOp); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | MS_LOG(INFO) << "Launching tree and begin iteration"; | ||||
| rc = myTree->Prepare(); | rc = myTree->Prepare(); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| std::cout << *myClient << std::endl; | std::cout << *myClient << std::endl; | ||||
| rc = myTree->Launch(); | rc = myTree->Launch(); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // Start the loop of reading tensors from our pipeline | // Start the loop of reading tensors from our pipeline | ||||
| DatasetIterator dI(myTree); | DatasetIterator dI(myTree); | ||||
| TensorRow tensorList; | TensorRow tensorList; | ||||
| rc = dI.FetchNextTensorRow(&tensorList); | rc = dI.FetchNextTensorRow(&tensorList); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| int rowCount = 0; | int rowCount = 0; | ||||
| while (!tensorList.empty()) { | while (!tensorList.empty()) { | ||||
| // Don't display these rows, just count them | // Don't display these rows, just count them | ||||
| MS_LOG(INFO) << "Row fetched #: " << rowCount; | MS_LOG(INFO) << "Row fetched #: " << rowCount; | ||||
| rc = dI.FetchNextTensorRow(&tensorList); | rc = dI.FetchNextTensorRow(&tensorList); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rowCount++; | rowCount++; | ||||
| } | } | ||||
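| // 10 RandomData rows x 4 repeats = 40 | |||||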
| ASSERT_EQ(rowCount, 40); | ASSERT_EQ(rowCount, 40); | ||||
| rc = myClient->DestroyCache(); | rc = myClient->DestroyCache(); | ||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| } | } | ||||
| @@ -15,6 +15,8 @@ | |||||
| """ | """ | ||||
| Testing cache operator with mappable datasets | Testing cache operator with mappable datasets | ||||
| """ | """ | ||||
| import os | |||||
| import pytest | |||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | import mindspore.dataset.transforms.vision.c_transforms as c_vision | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| @@ -25,6 +27,7 @@ DATA_DIR = "../data/dataset/testImageNetData/train/" | |||||
| GENERATE_GOLDEN = False | GENERATE_GOLDEN = False | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||||
| def test_cache_map_basic1(): | def test_cache_map_basic1(): | ||||
| """ | """ | ||||
| Test mappable leaf with cache op right over the leaf | Test mappable leaf with cache op right over the leaf | ||||
| @@ -53,7 +56,7 @@ def test_cache_map_basic1(): | |||||
| logger.info("test_cache_map_basic1 Ended.\n") | logger.info("test_cache_map_basic1 Ended.\n") | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||||
| def test_cache_map_basic2(): | def test_cache_map_basic2(): | ||||
| """ | """ | ||||
| Test mappable leaf with the cache op later in the tree above the map(decode) | Test mappable leaf with the cache op later in the tree above the map(decode) | ||||
| @@ -82,7 +85,7 @@ def test_cache_map_basic2(): | |||||
| logger.info("test_cache_map_basic2 Ended.\n") | logger.info("test_cache_map_basic2 Ended.\n") | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||||
| def test_cache_map_basic3(): | def test_cache_map_basic3(): | ||||
| """ | """ | ||||
| Test a repeat under mappable cache | Test a repeat under mappable cache | ||||
| @@ -116,7 +119,7 @@ def test_cache_map_basic3(): | |||||
| assert num_iter == 8 | assert num_iter == 8 | ||||
| logger.info('test_cache_map_basic3 Ended.\n') | logger.info('test_cache_map_basic3 Ended.\n') | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||||
| def test_cache_map_basic4(): | def test_cache_map_basic4(): | ||||
| """ | """ | ||||
| Test different rows result in core dump | Test different rows result in core dump | ||||
| @@ -141,7 +144,7 @@ def test_cache_map_basic4(): | |||||
| assert num_iter == 8 | assert num_iter == 8 | ||||
| logger.info('test_cache_map_basic4 Ended.\n') | logger.info('test_cache_map_basic4 Ended.\n') | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||||
| def test_cache_map_failure1(): | def test_cache_map_failure1(): | ||||
| """ | """ | ||||
| Test nested cache (failure) | Test nested cache (failure) | ||||
| @@ -15,6 +15,8 @@ | |||||
| """ | """ | ||||
| Testing cache operator with non-mappable datasets | Testing cache operator with non-mappable datasets | ||||
| """ | """ | ||||
| import os | |||||
| import pytest | |||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | import mindspore.dataset.transforms.vision.c_transforms as c_vision | ||||
| @@ -25,6 +27,7 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||||
| GENERATE_GOLDEN = False | GENERATE_GOLDEN = False | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||||
| def test_cache_nomap_basic1(): | def test_cache_nomap_basic1(): | ||||
| """ | """ | ||||
| A random dataset (a non mappable dataset) with a cache over it just after the leaf | A random dataset (a non mappable dataset) with a cache over it just after the leaf | ||||
| @@ -54,6 +57,7 @@ def test_cache_nomap_basic1(): | |||||
| logger.info("test_cache_nomap_basic1 Ended.\n") | logger.info("test_cache_nomap_basic1 Ended.\n") | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||||
| def test_cache_nomap_basic2(): | def test_cache_nomap_basic2(): | ||||
| """ | """ | ||||
| A random dataset (a non mappable dataset) with a cache over it just after the leaf | A random dataset (a non mappable dataset) with a cache over it just after the leaf | ||||
| @@ -85,6 +89,7 @@ def test_cache_nomap_basic2(): | |||||
| logger.info("test_cache_nomap_basic2 Ended.\n") | logger.info("test_cache_nomap_basic2 Ended.\n") | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||||
| def test_cache_nomap_basic3(): | def test_cache_nomap_basic3(): | ||||
| """ | """ | ||||
| A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf | A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf | ||||
| @@ -112,9 +117,21 @@ def test_cache_nomap_basic3(): | |||||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | logger.info("Number of data in ds1: {} ".format(num_iter)) | ||||
| assert num_iter == 12 | assert num_iter == 12 | ||||
| # Contact the server to get the statistics | |||||
| stat = some_cache.GetStat() | |||||
| cache_sz = stat.avg_cache_sz | |||||
| num_mem_cached = stat.num_mem_cached | |||||
| num_disk_cached = stat.num_disk_cached | |||||
| logger.info("Number of rows cached in memory: {}".format(num_mem_cached)) | |||||
| logger.info("Number of rows spilled to disk: {}".format(num_disk_cached)) | |||||
| logger.info("Average row cache size: {}".format(cache_sz)) | |||||
| logger.info("test_cache_nomap_basic3 Ended.\n") | logger.info("test_cache_nomap_basic3 Ended.\n") | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||||
| def test_cache_nomap_basic4(): | def test_cache_nomap_basic4(): | ||||
| """ | """ | ||||
| A TF reader dataset (a non mappable dataset) with a map decode and cache after it | A TF reader dataset (a non mappable dataset) with a map decode and cache after it | ||||
| @@ -155,6 +172,7 @@ def test_cache_nomap_basic4(): | |||||
| logger.info("test_cache_nomap_basic4 Ended.\n") | logger.info("test_cache_nomap_basic4 Ended.\n") | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||||
| def test_cache_nomap_basic5(): | def test_cache_nomap_basic5(): | ||||
| """ | """ | ||||
| A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf | A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf | ||||
| @@ -191,6 +209,7 @@ def test_cache_nomap_basic5(): | |||||
| logger.info("test_cache_nomap_basic5 Ended.\n") | logger.info("test_cache_nomap_basic5 Ended.\n") | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||||
| def test_cache_nomap_basic6(): | def test_cache_nomap_basic6(): | ||||
| """ | """ | ||||
| A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf | A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf | ||||
| @@ -230,6 +249,7 @@ def test_cache_nomap_basic6(): | |||||
| logger.info("test_cache_nomap_basic6 Ended.\n") | logger.info("test_cache_nomap_basic6 Ended.\n") | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||||
| def test_cache_nomap_basic7(): | def test_cache_nomap_basic7(): | ||||
| """ | """ | ||||
| A TF reader dataset (a non mappable dataset) that uses global shuffle, and is cached followed by | A TF reader dataset (a non mappable dataset) that uses global shuffle, and is cached followed by | ||||
| @@ -265,6 +285,7 @@ def test_cache_nomap_basic7(): | |||||
| logger.info("test_cache_nomap_basic7 Ended.\n") | logger.info("test_cache_nomap_basic7 Ended.\n") | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||||
| def test_cache_nomap_allowed_share1(): | def test_cache_nomap_allowed_share1(): | ||||
| """ | """ | ||||
| It is allowed to share the cache between the following two trees: | It is allowed to share the cache between the following two trees: | ||||
| @@ -280,7 +301,7 @@ def test_cache_nomap_allowed_share1(): | |||||
| ds.config.set_seed(1) | ds.config.set_seed(1) | ||||
| # This dataset has 3 records in it only | # This dataset has 3 records in it only | ||||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True, prefetch_size=32) | |||||
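| # prefetch_size is new in this call; presumably it caps how many rows the | |||||
| # client prefetches per request (the other arguments are unchanged). | |||||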
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) | ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) | ||||
| ds1 = ds1.repeat(4) | ds1 = ds1.repeat(4) | ||||
| @@ -300,6 +321,7 @@ def test_cache_nomap_allowed_share1(): | |||||
| logger.info("test_cache_nomap_allowed_share1 Ended.\n") | logger.info("test_cache_nomap_allowed_share1 Ended.\n") | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||||
| def test_cache_nomap_allowed_share2(): | def test_cache_nomap_allowed_share2(): | ||||
| """ | """ | ||||
| It is allowed to share the cache between the following two trees (with map decode): | It is allowed to share the cache between the following two trees (with map decode): | ||||
| @@ -341,6 +363,7 @@ def test_cache_nomap_allowed_share2(): | |||||
| logger.info("test_cache_nomap_allowed_share2 Ended.\n") | logger.info("test_cache_nomap_allowed_share2 Ended.\n") | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||||
| def test_cache_nomap_allowed_share3(): | def test_cache_nomap_allowed_share3(): | ||||
| """ | """ | ||||
| It is allowed to share the cache between the following two trees (different shard ids): | It is allowed to share the cache between the following two trees (different shard ids): | ||||
| @@ -376,6 +399,7 @@ def test_cache_nomap_allowed_share3(): | |||||
| logger.info("test_cache_nomap_allowed_share3 Ended.\n") | logger.info("test_cache_nomap_allowed_share3 Ended.\n") | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||||
| def test_cache_nomap_allowed_share4(): | def test_cache_nomap_allowed_share4(): | ||||
| """ | """ | ||||
| It is allowed to share the cache between the following two trees: | It is allowed to share the cache between the following two trees: | ||||
| @@ -414,6 +438,7 @@ def test_cache_nomap_allowed_share4(): | |||||
| logger.info("test_cache_nomap_allowed_share4 Ended.\n") | logger.info("test_cache_nomap_allowed_share4 Ended.\n") | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||||
| def test_cache_nomap_disallowed_share1(): | def test_cache_nomap_disallowed_share1(): | ||||
| """ | """ | ||||
| It is not allowed to share the cache between the following two trees: | It is not allowed to share the cache between the following two trees: | ||||