@@ -24,6 +24,11 @@ if (ENABLE_TDTQUE)
    add_definitions(-D ENABLE_TDTQUE)
    message(STATUS "TDT queue is enabled")
endif ()
if (MS_BUILD_GRPC)
    set(ENABLE_CACHE true)
    add_definitions(-D ENABLE_CACHE)
    message(STATUS "Cache is enabled")
endif()
# code coverage
# option(ENABLE_COVERAGE "Enable code coverage report" OFF)
@@ -47,10 +52,6 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
################## Include sub-modules ###############################
add_subdirectory(util)
add_subdirectory(core)
@@ -70,8 +71,6 @@ add_dependencies(engine-datasetops-source-sampler core)
add_dependencies(engine-datasetops core)
add_dependencies(engine-datasetops-mapop core)
add_dependencies(engine-opt core)
add_dependencies(engine-cache-client core)
add_dependencies(engine-cache-server core)
add_dependencies(engine-perf core)
add_dependencies(engine-gnn core)
add_dependencies(engine core)
@@ -85,7 +84,11 @@ endif()
if (ENABLE_TDTQUE)
    add_dependencies(engine-tdt core)
endif ()
if (ENABLE_CACHE)
    add_dependencies(engine-datasetops engine-cache-client)
    add_dependencies(engine-cache-client core)
    add_dependencies(engine-cache-server core)
endif ()
################### Create _c_dataengine Library ######################
set(submodules
    $<TARGET_OBJECTS:core>
@@ -105,7 +108,6 @@ set(submodules
    $<TARGET_OBJECTS:engine-datasetops>
    $<TARGET_OBJECTS:engine-opt>
    $<TARGET_OBJECTS:engine-cache-client>
    $<TARGET_OBJECTS:engine-cache-server>
    $<TARGET_OBJECTS:engine>
    $<TARGET_OBJECTS:text>
    $<TARGET_OBJECTS:text-kernels>
@@ -123,8 +125,6 @@ else ()
    add_library(_c_dataengine SHARED ${submodules})
endif ()
add_dependencies(_c_dataengine generated_engine_files)
if (ENABLE_PYTHON)
    set_target_properties(_c_dataengine PROPERTIES
        PREFIX "${PYTHON_MODULE_PREFIX}"
@@ -187,6 +187,6 @@ else()
    endif ()
endif()
if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
    if (MS_BUILD_GRPC)
        target_link_libraries(_c_dataengine PRIVATE mindspore::grpc++)
    endif()
endif()
@@ -22,7 +22,25 @@ namespace dataset {
PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) {
                  (void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
                    .def(py::init<uint32_t, uint64_t, bool>());
                    .def(
                      py::init([](session_id_type id, uint64_t mem_sz, bool spill, int32_t port, int32_t prefetch_sz) {
                        std::shared_ptr<CacheClient> cc;
                        CacheClient::Builder builder;
                        builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPort(port).SetPrefetchSize(
                          prefetch_sz);
                        THROW_IF_ERROR(builder.Build(&cc));
                        return cc;
                      }))
                    .def("GetStat", [](CacheClient &cc) {
                      CacheServiceStat stat{};
                      THROW_IF_ERROR(cc.GetStat(&stat));
                      return stat;
                    });
                  (void)py::class_<CacheServiceStat>(*m, "CacheServiceStat")
                    .def(py::init<>())
                    .def_readwrite("avg_cache_sz", &CacheServiceStat::avg_cache_sz)
                    .def_readwrite("num_mem_cached", &CacheServiceStat::num_mem_cached)
                    .def_readwrite("num_disk_cached", &CacheServiceStat::num_disk_cached);
                }));
}  // namespace dataset
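For orientation, here is a minimal sketch of the native-side construction the new `py::init` lambda above performs. The session id and cache size are hypothetical values; the port and prefetch size mirror the defaults this patch gives the Builder.

```cpp
// Illustrative only: building a CacheClient via its Builder (values are hypothetical).
std::shared_ptr<CacheClient> cc;
CacheClient::Builder builder;
builder.SetSessionId(1234)   // e.g. a session id handed out by `cache_admin -g`
  .SetCacheMemSz(4)          // memory cap for the cache; 0 means unlimited
  .SetSpill(true)            // spill rows to disk when over the cap
  .SetPort(50052)            // Builder default port in this patch
  .SetPrefetchSize(20);      // Builder default prefetch size in this patch
Status rc = builder.Build(&cc);  // runs SanityCheck() before constructing
```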
@@ -72,7 +72,8 @@ constexpr uint32_t kCfgMonitorSamplingInterval = 10;
// Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h)
constexpr uint8_t kCVInvalidType = 255;
using connection_id_type = int64_t;
using connection_id_type = uint64_t;
using session_id_type = uint32_t;
using row_id_type = int64_t;
}  // namespace dataset
}  // namespace mindspore
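The switch to `uint64_t` matters because the connection id is a packed value: `CacheClient::CreateCache` later in this patch combines the 32-bit session id and the 32-bit cache crc into a single identifier. An illustrative helper (not part of the patch) makes the layout explicit:

```cpp
// Sketch only: the packing performed inline in CacheClient::CreateCache.
// The session id occupies the upper 32 bits and the cache crc the lower 32,
// so an unsigned 64-bit type keeps ids with a high session bit from going negative.
inline connection_id_type MakeConnectionId(session_id_type session_id, uint32_t crc) {
  return (static_cast<connection_id_type>(session_id) << 32) | crc;
}
```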
@@ -20,10 +20,8 @@ if (ENABLE_PYTHON)
    target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
endif()
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf engine-cache-client engine-datasetops-mapop)
if (ENABLE_TDTQUE)
    add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf
        engine-cache-client engine-cache-server engine-datasetops-mapop)
else ()
    add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf
        engine-cache-client engine-cache-server engine-datasetops-mapop)
    add_dependencies(engine engine-tdt)
endif ()
@@ -1,8 +1,47 @@
include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-cache-client OBJECT
    cache_client.cc
    cache_fbb.cc
    cache_request.cc)
add_library(engine-cache-server OBJECT
    cache_service.cc
    cache_server.cc)
if (ENABLE_CACHE)
    ms_grpc_generate(CACHE_GRPC_SRCS CACHE_GRPC_HDRS cache_grpc.proto)
    target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS} cache_grpc_client.cc)
    add_library(engine-cache-server OBJECT
        ${CACHE_GRPC_SRCS}
        cache_grpc_server.cc
        cache_arena.cc
        cache_service.cc
        cache_server.cc)
    add_executable(cache_server cache_main.cc)
    target_link_libraries(cache_server
        engine-cache-server
        $<TARGET_OBJECTS:utils>
        mindspore
        mindspore::glog
        mindspore::protobuf
        mindspore::grpc++
        mindspore_gvar
        ${PYTHON_LIBRARIES}
        ${SECUREC_LIBRARY}
        pthread)
    add_executable(cache_admin cache_admin.cc cache_admin_arg.cc)
    target_link_libraries(cache_admin _c_dataengine _c_mindrecord ${PYTHON_LIBRARIES} mindspore::glog)
    add_dependencies(engine-cache-server generated_engine_files)
else ()
    ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PROTO_HDRS cache_grpc.proto)
    target_sources(engine-cache-client PUBLIC ${CACHE_PROTO_SRCS})
endif ()
add_dependencies(engine-cache-client generated_engine_files)
@@ -0,0 +1,70 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <unistd.h>
#include <iostream>
#ifdef USE_GLOG
#include <glog/logging.h>
#endif
#include "minddata/dataset/engine/cache/cache_admin_arg.h"
namespace ds = mindspore::dataset;
int main(int argc, char **argv) {
  ds::Status rc;
  ds::CacheAdminArgHandler args;
  std::stringstream arg_stream;
#ifdef USE_GLOG
  FLAGS_log_dir = "/tmp";
  google::InitGoogleLogging(argv[0]);
#endif
  std::string warningMsg;
  warningMsg.reserve(512);
  warningMsg += "WARNING:\n";
  warningMsg += "cache_admin and the cache server that it controls are currently only for experimental research";
  warningMsg += " purposes.\n";
  warningMsg += "They are not intended for general availability yet as they may not be stable. Use them at your own risk.\n";
  // A warning message until the code is mature enough.
  std::cerr << warningMsg << std::endl;
  if (argc == 1) {
    args.Help();
    return 0;
  }
  // ingest all the args into a string stream for parsing
  for (int i = 1; i < argc; ++i) {
    arg_stream << " " << std::string(argv[i]);
  }
  // Parse the args
  rc = args.ParseArgStream(&arg_stream);
  if (!rc.IsOk()) {
    std::cerr << rc.ToString() << std::endl;
    return 1;
  }
  // Execute the command
  rc = args.RunCommand();
  if (!rc.IsOk()) {
    std::cerr << rc.ToString() << std::endl;
    return 1;
  }
  return 0;
}
@@ -0,0 +1,396 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_admin_arg.h"
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <unistd.h>
#include <cerrno>
#include <iostream>
#include <string>
#include <cstdlib>
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
const char CacheAdminArgHandler::kDefaultHost[] = "127.0.0.1";
const char CacheAdminArgHandler::kServerBinary[] = "cache_server";
const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp";
CacheAdminArgHandler::CacheAdminArgHandler()
    : port_(kDefaultPort),
      session_id_(0),
      num_workers_(kDefaultNumWorkers),
      shm_mem_sz_(kDefaultSharedMemorySizeInGB),
      log_level_(kDefaultLogLevel),
      hostname_(kDefaultHost),
      spill_dir_(kDefaultSpillDir),
      command_id_(CommandId::kCmdUnknown) {
  // Initialize the command mappings
  arg_map_["-h"] = ArgValue::kArgHost;
  arg_map_["--hostname"] = ArgValue::kArgHost;
  arg_map_["-p"] = ArgValue::kArgPort;
  arg_map_["--port"] = ArgValue::kArgPort;
  arg_map_["--start"] = ArgValue::kArgStart;
  arg_map_["--stop"] = ArgValue::kArgStop;
  arg_map_["--help"] = ArgValue::kArgHelp;
  arg_map_["--generate_session"] = ArgValue::kArgGenerateSession;
  arg_map_["-g"] = ArgValue::kArgGenerateSession;
  arg_map_["--destroy_session"] = ArgValue::kArgDestroySession;
  arg_map_["-d"] = ArgValue::kArgDestroySession;
  arg_map_["--spilldir"] = ArgValue::kArgSpillDir;
  arg_map_["-s"] = ArgValue::kArgSpillDir;
  arg_map_["-w"] = ArgValue::kArgNumWorkers;
  arg_map_["--workers"] = ArgValue::kArgNumWorkers;
  arg_map_["-m"] = ArgValue::kArgSharedMemorySize;
  arg_map_["--shared_memory_size"] = ArgValue::kArgSharedMemorySize;
  arg_map_["-l"] = ArgValue::kArgLogLevel;
  arg_map_["--minloglevel"] = ArgValue::kArgLogLevel;
  // Initialize argument tracker with false values
  for (int16_t i = 0; i < static_cast<int16_t>(ArgValue::kArgNumArgs); ++i) {
    ArgValue currAV = static_cast<ArgValue>(i);
    used_args_[currAV] = false;
  }
}
Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
                                       CommandId command_id) {
  // Detect if the user tried to provide this argument more than once
  ArgValue selected_arg = arg_map_[option];
  if (used_args_[selected_arg]) {
    std::string err_msg = "The " + option + " argument was given more than once.";
    return Status(StatusCode::kSyntaxError, err_msg);
  }
  // Flag that this arg is used now
  used_args_[selected_arg] = true;
  // Some options are just arguments, for example "--port 50052" is not a command, it's just an argument.
  // Other options are actual commands, for example "--destroy_session 1234". This executes the destroy session.
  // If this option is also a command, make sure there has not been multiple commands given before assigning it.
  if (command_id != CommandId::kCmdUnknown) {
    if (command_id_ != CommandId::kCmdUnknown) {
      std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
      return Status(StatusCode::kSyntaxError, err_msg);
    } else {
      command_id_ = command_id;
    }
  }
  std::string value_as_string;
  // Fetch the argument from the arg stream into a string
  *arg_stream >> value_as_string;
  if (value_as_string.empty()) {
    std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
    return Status(StatusCode::kSyntaxError, err_msg);
  }
  // Now, attempt to convert the value from its string form into a number
  try {
    *out_arg = std::stoul(value_as_string);
  } catch (const std::exception &e) {
    std::string err_msg = "Invalid numeric value: " + value_as_string;
    return Status(StatusCode::kSyntaxError, err_msg);
  }
  return Status::OK();
}
Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream,
                                       CommandId command_id) {
  // Detect if the user tried to provide this argument more than once
  ArgValue selected_arg = arg_map_[option];
  if (used_args_[selected_arg]) {
    std::string err_msg = "The " + option + " argument was given more than once.";
    return Status(StatusCode::kSyntaxError, err_msg);
  }
  // Flag that this arg is used now
  used_args_[selected_arg] = true;
  // Some options are just arguments, for example "--hostname 127.0.0.1" is not a command, it's just an argument.
  // Other options are actual commands, for example "--start".
  // If this option is also a command, make sure there has not been multiple commands given before assigning it.
  if (command_id != CommandId::kCmdUnknown) {
    if (command_id_ != CommandId::kCmdUnknown) {
      std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
      return Status(StatusCode::kSyntaxError, err_msg);
    } else {
      command_id_ = command_id;
    }
  }
  // If there is no argument to get, such as the --start command, then out_arg will be a nullptr.
  if (out_arg != nullptr) {
    // Fetch the argument from the arg stream into a string
    *arg_stream >> *out_arg;
    if (out_arg->empty()) {
      std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
      return Status(StatusCode::kSyntaxError, err_msg);
    }
  }
  return Status::OK();
}
Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) {
  std::string tok;
  while (*arg_stream >> tok) {
    switch (arg_map_[tok]) {
      case ArgValue::kArgHost: {
        RETURN_IF_NOT_OK(AssignArg(tok, &hostname_, arg_stream));
        break;
      }
      case ArgValue::kArgPort: {
        RETURN_IF_NOT_OK(AssignArg(tok, &port_, arg_stream));
        break;
      }
      case ArgValue::kArgStart: {
        RETURN_IF_NOT_OK(AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdStart));
        break;
      }
      case ArgValue::kArgStop: {
        RETURN_IF_NOT_OK(AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdStop));
        break;
      }
      case ArgValue::kArgGenerateSession: {
        RETURN_IF_NOT_OK(
          AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdGenerateSession));
        break;
      }
      case ArgValue::kArgHelp: {
        command_id_ = CommandId::kCmdHelp;
        break;
      }
      case ArgValue::kArgDestroySession: {
        // session_id is an unsigned type. We may need to template the AssignArg function so that
        // it can handle different flavours of integers instead of just int32_t.
        int32_t session_int;
        RETURN_IF_NOT_OK(AssignArg(tok, &session_int, arg_stream, CommandId::kCmdDestroySession));
        session_id_ = session_int;
        break;
      }
      case ArgValue::kArgNumWorkers: {
        RETURN_IF_NOT_OK(AssignArg(tok, &num_workers_, arg_stream));
        break;
      }
      case ArgValue::kArgSpillDir: {
        RETURN_IF_NOT_OK(AssignArg(tok, &spill_dir_, arg_stream));
        break;
      }
      case ArgValue::kArgSharedMemorySize: {
        RETURN_IF_NOT_OK(AssignArg(tok, &shm_mem_sz_, arg_stream));
        break;
      }
      case ArgValue::kArgLogLevel: {
        RETURN_IF_NOT_OK(AssignArg(tok, &log_level_, arg_stream));
        break;
      }
      default: {
        // Save space delimited trailing arguments
        trailing_args_ += (" " + tok);
        break;
      }
    }
  }
  RETURN_IF_NOT_OK(Validate());
  return Status::OK();
}
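A minimal sketch of how a command line flows through the parser above; the flag values are hypothetical. Each token is looked up in `arg_map_`, dispatched to `AssignArg`, and anything unrecognized is collected into `trailing_args_` for `Validate()` to reject:

```cpp
// Illustrative only: " --start -w 16 -s /tmp/spill" yields
// command_id_ = kCmdStart, num_workers_ = 16, spill_dir_ = "/tmp/spill".
std::stringstream ss;
ss << " --start" << " -w 16" << " -s /tmp/spill";
mindspore::dataset::CacheAdminArgHandler args;
mindspore::dataset::Status rc = args.ParseArgStream(&ss);
```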
Status CacheAdminArgHandler::Validate() {
  // This sanity check is delayed until now in case there may be valid use-cases of trailing args.
  // Any unhandled arguments at this point are an error.
  if (!trailing_args_.empty()) {
    std::string err_msg = "Invalid arguments provided: " + trailing_args_;
    return Status(StatusCode::kSyntaxError, err_msg);
  }
  // The user must pick at least one command. i.e. it's meaningless to just give a hostname or port but no command to
  // run.
  if (command_id_ == CommandId::kCmdUnknown) {
    std::string err_msg = "No command provided";
    return Status(StatusCode::kSyntaxError, err_msg);
  }
  // Additional checks here
  if (num_workers_ < 1) return Status(StatusCode::kSyntaxError, "Number of workers must be a positive value.");
  if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3).");
  // port range check?
  return Status::OK();
}
Status CacheAdminArgHandler::RunCommand() {
  switch (command_id_) {
    case CommandId::kCmdHelp: {
      Help();
      break;
    }
    case CommandId::kCmdStart: {
      RETURN_IF_NOT_OK(StartServer());
      break;
    }
    case CommandId::kCmdStop: {
      RETURN_IF_NOT_OK(StopServer());
      break;
    }
    case CommandId::kCmdGenerateSession: {
      CacheClientGreeter comm(hostname_, port_, 1);
      RETURN_IF_NOT_OK(comm.ServiceStart());
      auto rq = std::make_shared<GenerateSessionIdRequest>();
      RETURN_IF_NOT_OK(comm.HandleRequest(rq));
      RETURN_IF_NOT_OK(rq->Wait());
      std::cout << rq->GetSessionId() << std::endl;
      break;
    }
    case CommandId::kCmdDestroySession: {
      CacheClientGreeter comm(hostname_, port_, 1);
      RETURN_IF_NOT_OK(comm.ServiceStart());
      CacheClientInfo cinfo;
      cinfo.set_session_id(session_id_);
      auto rq = std::make_shared<DropSessionRequest>(cinfo);
      RETURN_IF_NOT_OK(comm.HandleRequest(rq));
      RETURN_IF_NOT_OK(rq->Wait());
      std::cout << "Drop session successful" << std::endl;
      break;
    }
    default: {
      RETURN_STATUS_UNEXPECTED("Invalid cache admin command id.");
      break;
    }
  }
  return Status::OK();
}
Status CacheAdminArgHandler::StartServer() {
  // There currently does not exist any "install path" or method to identify which path the installed binaries will
  // exist in. As a temporary approach, we will assume that the server binary shall exist in the same path as the
  // cache_admin binary (this process).
  const std::string self_proc = "/proc/self/exe";
  std::string canonical_path;
  canonical_path.resize(400);  // PATH_MAX is large. This value should be big enough for our use.
  // Some lower level OS library calls are needed here to determine the binary path.
  // Fetch the path of the cache_admin binary into a C character array and then truncate off the binary name so that
  // we are left with only the absolute path.
  if (realpath(self_proc.data(), canonical_path.data()) == nullptr) {
    std::string err_msg = "Failed to identify cache admin binary path: " + std::to_string(errno);
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  canonical_path.resize(strlen(canonical_path.data()));
  auto last_separator = canonical_path.find_last_of('/');
  CHECK_FAIL_RETURN_UNEXPECTED(last_separator != std::string::npos, "No / found");
  // truncate the binary name so we are left with the absolute path of cache_admin binary
  canonical_path.resize(last_separator + 1);
  std::string cache_server_binary = canonical_path + std::string(kServerBinary);
  // Create a pipe before we fork. If all goes well, the child will run as a daemon in the background
  // and never returns until shutdown. If there is any error, the child will notify us through the pipe.
  int fd[2];
  if (pipe(fd) == -1) {
    std::string err_msg = "Failed to create a pipe for communication " + std::to_string(errno);
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  // fork the child process to become the daemon
  pid_t pid;
  pid = fork();
  // failed to fork
  if (pid < 0) {
    std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno);
    RETURN_STATUS_UNEXPECTED(err_msg);
  } else if (pid > 0) {
    // As a parent, we close the write end. We only listen.
    close(fd[1]);
    dup2(fd[0], 0);
    close(fd[0]);
    wait(nullptr);
    std::string msg;
    const int32_t buf_sz = 1024;
    msg.resize(buf_sz);
    auto n = read(0, msg.data(), buf_sz);
    if (n < 0) {
      std::string err_msg = "Failed to read from the pipe " + std::to_string(errno);
      RETURN_STATUS_UNEXPECTED(err_msg);
    }
    msg.resize(n);
    std::cout << msg << std::endl;
    return Status::OK();
  } else {
    // Child here ...
    // Close all stdin, redirect stdout and stderr to the write end of the pipe.
    close(fd[0]);
    dup2(fd[1], 1);
    dup2(fd[1], 2);
    close(0);
    close(fd[1]);
    // exec the cache server binary in this process
    std::string port_string = std::to_string(port_);
    std::string workers_string = std::to_string(num_workers_);
    std::string shared_memory_string = std::to_string(shm_mem_sz_);
    std::string minloglevel_string = std::to_string(log_level_);
    std::string daemonize_string = "true";
    char *argv[8];
    argv[0] = cache_server_binary.data();  // First arg is usually the binary name
    argv[1] = spill_dir_.data();
    argv[2] = workers_string.data();
    argv[3] = port_string.data();
    argv[4] = shared_memory_string.data();
    argv[5] = minloglevel_string.data();
    argv[6] = daemonize_string.data();
    argv[7] = nullptr;
    // Now exec the binary
    execv(argv[0], argv);
    // If exec succeeds, this line is never reached because the process image has been replaced.
    // We only get here if exec failed.
    std::string err_msg = "Failed to exec cache server: " + cache_server_binary;
    std::cerr << err_msg << std::endl;
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
}
Status CacheAdminArgHandler::StopServer() {
  CacheClientGreeter comm(hostname_, port_, 1);
  RETURN_IF_NOT_OK(comm.ServiceStart());
  auto rq = std::make_shared<ShutdownRequest>();
  RETURN_IF_NOT_OK(comm.HandleRequest(rq));
  return Status::OK();
}
void CacheAdminArgHandler::Help() {
  std::cerr << "Syntax:\n";
  std::cerr << " cache_admin [--start | --stop]\n";
  std::cerr << " [ [-h | --hostname] <hostname> ]\n";
  std::cerr << " [ [-p | --port] <port number> ]\n";
  std::cerr << " [ [-g | --generate_session] ]\n";
  std::cerr << " [ [-d | --destroy_session] <session id> ]\n";
  std::cerr << " [ [-w | --workers] <number of workers> ]\n";
  std::cerr << " [ [-s | --spilldir] <spilling directory> ]\n";
  std::cerr << " [ [-m | --shared_memory_size] <shared memory size> ]\n";
  std::cerr << " [ [-l | --minloglevel] <log level> ]\n";
  std::cerr << " [--help]" << std::endl;
}
}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,105 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <sstream>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/cache/cache_client.h"
namespace mindspore {
namespace dataset {
class CacheAdminArgHandler {
 public:
  static constexpr int32_t kDefaultPort = 50052;
  static constexpr int32_t kDefaultNumWorkers = 32;
  static constexpr int32_t kDefaultSharedMemorySizeInGB = 4;
  static constexpr int32_t kDefaultLogLevel = 1;
  static const char kDefaultHost[];
  static const char kServerBinary[];
  static const char kDefaultSpillDir[];
  // These are the actual command types to execute
  enum class CommandId : int16_t {
    kCmdHelp = 0,
    kCmdStart = 1,
    kCmdStop = 2,
    kCmdGenerateSession = 3,
    kCmdDestroySession = 4,
    kCmdUnknown = 32767
  };
  CacheAdminArgHandler();
  ~CacheAdminArgHandler() = default;
  Status ParseArgStream(std::stringstream *arg_stream);
  Status RunCommand();
  void Help();
 private:
  // These are the supported argument string integer mappings
  enum class ArgValue : int16_t {
    kArgUnknown = 0,  // Must be at position 0. invalid map lookups in arg_map_ default to value 0
    kArgStart = 1,
    kArgStop = 2,
    kArgHost = 3,
    kArgPort = 4,
    kArgHelp = 5,
    kArgGenerateSession = 6,
    kArgDestroySession = 7,
    kArgSpillDir = 8,
    kArgNumWorkers = 9,
    kArgSharedMemorySize = 10,
    kArgLogLevel = 11,
    kArgNumArgs = 12  // Must be the last position to provide a count
  };
  Status StartServer();
  Status StopServer();
  Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
                   CommandId command_id = CommandId::kCmdUnknown);
  Status AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream,
                   CommandId command_id = CommandId::kCmdUnknown);
  Status Validate();
  CommandId command_id_;
  int32_t port_;
  int32_t num_workers_;
  int32_t shm_mem_sz_;
  int32_t log_level_;
  session_id_type session_id_;
  std::string hostname_;
  std::string spill_dir_;
  std::string trailing_args_;
  std::map<std::string, ArgValue> arg_map_;
  std::map<ArgValue, bool> used_args_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
@@ -0,0 +1,73 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB)
    : Arena::Arena(val_in_GB * 1024), port_(port), shmid_(-1) {}
CachedSharedMemoryArena::~CachedSharedMemoryArena() {
#if CACHE_LOCAL_CLIENT
  if (this->ptr_ != nullptr && this->ptr_ != reinterpret_cast<void *>(-1)) {
    shmdt(this->ptr_);
  }
  this->ptr_ = nullptr;
  if (shmid_ != -1) {
    shmctl(shmid_, IPC_RMID, nullptr);
    // Also remove the path we use to generate ftok.
    Path p(PortToUnixSocketPath(port_));
    (void)p.Remove();
  }
#endif
}
Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
                                            size_t val_in_GB) {
  RETURN_UNEXPECTED_IF_NULL(out);
#if CACHE_LOCAL_CLIENT
  auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB);
  if (ba == nullptr) {
    return Status(StatusCode::kOutOfMemory);
  }
  // Transfer ownership of this pointer. Any error later in the processing
  // will be cleaned up by the destructor of *out.
  (*out).reset(ba);
  // Generate the ftok key from the port.
  int err;
  auto shm_key = PortToFtok(port, &err);
  if (shm_key == (key_t)-1) {
    std::string errMsg = "Ftok failed with errno " + std::to_string(err);
    RETURN_STATUS_UNEXPECTED(errMsg);
  }
  auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
  ba->shmid_ = shmget(shm_key, ba->size_in_bytes_, IPC_CREAT | IPC_EXCL | access_mode);
  if (ba->shmid_ != -1) {
    ba->ptr_ = shmat(ba->shmid_, nullptr, 0);
    if (ba->ptr_ == reinterpret_cast<void *>(-1)) {
      RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
    }
  } else {
    RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno));
  }
  uint64_t num_blks = ba->size_in_bytes_ / ARENA_BLK_SZ;
  MS_LOG(DEBUG) << "Shared memory pool has " << num_blks << " blocks, each of size " << ARENA_BLK_SZ << ".";
  ba->tr_.Insert(0, num_blks);
#endif
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
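For context, a sketch of how a local client could locate the segment the server created above, assuming standard System V semantics (this is not code from the patch). `PortToFtok` in cache_common.h derives the key from the unix socket file for the port, so both sides agree on it:

```cpp
// Illustrative only: looking up an existing shared memory segment by port.
int err = 0;
key_t key = mindspore::dataset::PortToFtok(50052, &err);  // hypothetical port
// size 0 and no IPC_CREAT means "find an existing segment", never create one
int shmid = (key != (key_t)-1) ? shmget(key, 0, 0) : -1;
void *base = (shmid != -1) ? shmat(shmid, nullptr, 0) : nullptr;
```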
@@ -0,0 +1,52 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
#include <memory>
#include <string>
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/engine/cache/cache_common.h"
namespace mindspore {
namespace dataset {
/// This is a derived class of Arena but resides in shared memory
class CachedSharedMemoryArena : public Arena {
 public:
  ~CachedSharedMemoryArena() override;
  /// \brief Create an Arena in shared memory
  /// \param[out] out Pointer to a unique_ptr that receives the arena
  /// \param[in] port TCP/IP port, used to derive the shared memory key
  /// \param[in] val_in_GB Size of the shared memory in gigabytes
  /// \return Status object
  static Status CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port, size_t val_in_GB);
  /// \brief This returns where we attach to the shared memory.
  /// Some gRPC requests will ask for a shared memory block, and
  /// we can't return the absolute address as this makes no sense
  /// in the client. So instead we will return an address relative
  /// to the base address of the shared memory where we attach to.
  /// \return Base address of the shared memory.
  const void *SharedMemoryBaseAddr() const { return this->ptr_; }
 private:
  int32_t port_;
  int shmid_;
  /// Private constructor. Not to be called directly.
  CachedSharedMemoryArena(int32_t port, size_t val_in_GB);
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
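The comment on `SharedMemoryBaseAddr()` implies the following rebasing on the client side; a sketch with hypothetical names, since the gRPC plumbing lives elsewhere in the patch:

```cpp
// Illustrative only: the server ships an offset relative to its own attach
// address; the client adds that offset to wherever *it* attached the segment.
const void *base = arena->SharedMemoryBaseAddr();  // 'arena' is hypothetical here
int64_t offset = 4096;                             // offset received over gRPC (hypothetical)
const void *abs_addr = static_cast<const char *>(base) + offset;
```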
@@ -17,29 +17,45 @@
#include <iomanip>
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_fbb.h"
#include "minddata/dataset/util/bit.h"
namespace mindspore {
namespace dataset {
// Constructor
CacheClient::CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill)
    : server_connection_id_(0), session_id_(session_id), cache_crc_(0), cache_mem_sz_(cache_mem_sz), spill_(spill) {}
CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname,
                         int32_t port, int32_t num_workers, int32_t prefetch_size)
    : server_connection_id_(0),
      cache_mem_sz_(cache_mem_sz),
      spill_(spill),
      local_bypass_(false),
      hostname_(std::move(hostname)),
      port_(port),
      num_workers_(num_workers),
      prefetch_size_(prefetch_size) {
  cinfo_.set_session_id(session_id);
  comm_ = std::make_shared<CacheClientGreeter>(hostname_, port_, num_workers_);
}
// Print method for displaying cache details
void CacheClient::Print(std::ostream &out) const {
  out << " Session id: " << session_id_ << "\n Cache crc: " << cache_crc_
      << "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << cache_mem_sz_
      << "\n Spilling: " << std::boolalpha << spill_;
  out << " Session id: " << session_id() << "\n Cache crc: " << cinfo_.crc()
      << "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << getCacheMemSz()
      << "\n Spilling: " << std::boolalpha << isSpill() << "\n Hostname: " << getHostname()
      << "\n Port: " << getPort() << "\n Number of rpc workers: " << getNumWorkers()
      << "\n Prefetch size: " << getPrefetchSize() << "\n Local client support: " << std::boolalpha
      << SupportLocalClient();
}
Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
  CacheRowRequest rq(server_connection_id_, cookie());
  RETURN_IF_NOT_OK(rq.SerializeCacheRowRequest(row));
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
  RETURN_IF_NOT_OK(rq.Wait());
  auto rq = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
  RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(this, row));
  RETURN_IF_NOT_OK(PushRequest(rq));
  RETURN_IF_NOT_OK(rq->Wait());
  if (row_id_from_server != nullptr) {
    *row_id_from_server = rq.GetRowIdAfterCache();
    *row_id_from_server = rq->GetRowIdAfterCache();
  }
  return Status::OK();
}
@@ -47,29 +63,19 @@ Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_serv
Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
  std::unique_ptr<DataBuffer> db_ptr = std::move(in);
  auto num_rows = db_ptr->NumRows();
  std::vector<TensorRow> all_rows;
  // We will send the requests async first on all rows and do a final wait.
  if (num_rows > 0) {
    all_rows.reserve(num_rows);
    // Break down the DataBuffer into TensorRow. We will send the requests async
    // and then do a final wait.
    MemGuard<CacheRowRequest> rq_arr;
    RETURN_IF_NOT_OK(rq_arr.allocate(num_rows, server_connection_id_, cookie()));
    CacheServer &cs = CacheServer::GetInstance();
    auto arr = std::make_unique<std::shared_ptr<CacheRowRequest>[]>(num_rows);
    for (auto i = 0; i < num_rows; ++i) {
      TensorRow row;
      auto rq = rq_arr[i];
      RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
      RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(row));
      RETURN_IF_NOT_OK(cs.PushRequest(rq));
      // We can't let row go out of scope. Otherwise it will free all the tensor memory.
      // So park it in the vector. When this function go out of scope, its memory
      // will be freed.
      all_rows.push_back(std::move(row));
      arr[i] = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
      RETURN_IF_NOT_OK(arr[i]->SerializeCacheRowRequest(this, row));
      RETURN_IF_NOT_OK(PushRequest(arr[i]));
    }
    // Now we wait for the requests to be done.
    // Now we wait for them to come back
    for (auto i = 0; i < num_rows; ++i) {
      auto rq = rq_arr[i];
      RETURN_IF_NOT_OK(rq->Wait());
      RETURN_IF_NOT_OK(arr[i]->Wait());
    }
  }
  return Status::OK();
@@ -77,11 +83,21 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
  RETURN_UNEXPECTED_IF_NULL(out);
  BatchFetchRequest rq(server_connection_id_, row_id);
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
  RETURN_IF_NOT_OK(rq.Wait());
  RETURN_IF_NOT_OK(rq.RestoreRows(out));
  return Status::OK();
  auto rq = std::make_shared<BatchFetchRequest>(server_connection_id_, row_id, SupportLocalClient());
  RETURN_IF_NOT_OK(PushRequest(rq));
  RETURN_IF_NOT_OK(rq->Wait());
  int64_t mem_addr = -1;  // initialized so the free below is skipped if RestoreRows fails before setting it
  Status rc = rq->RestoreRows(out, comm_->SharedMemoryBaseAddr(), &mem_addr);
  // Free the memory by sending a request back to the server.
  if (mem_addr != -1) {
    auto mfree_req = std::make_shared<FreeSharedBlockRequest>(server_connection_id_, mem_addr);
    Status rc2 = PushRequest(mfree_req);
    // But we won't wait for the result for the sake of performance.
    if (rc.IsOk() && rc2.IsError()) {
      rc = rc2;
    }
  }
  return rc;
}
Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
@@ -108,40 +124,44 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
  // to create a cache and some other tree is trying to use the same cache.
  // That is allowed, however the crc better match!
  if (server_connection_id_) {
    if (cache_crc_ != tree_crc) {
    if (cinfo_.crc() != tree_crc) {
      RETURN_STATUS_UNEXPECTED("Attempt to re-use a cache for a different tree!");
    }
    // Check the state of the server. For non-mappable case where there is a build phase and a fetch phase, we should
    // skip the build phase.
    lck.Unlock();  // GetStat will grab the mutex again. So unlock it to prevent deadlock.
    CacheClient::ServiceStat stat{};
    CacheServiceStat stat{};
    RETURN_IF_NOT_OK(GetStat(&stat));
    if (stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase)) {
      return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase");
    }
  } else {
    cache_crc_ = tree_crc;  // It's really a new cache we're creating so save our crc in the client
    // Combine the session and crc. This will form our client cache identifier.
    connection_id_type connection_identification = (static_cast<uint64_t>(session_id_) << 32) | cache_crc_;
    cinfo_.set_crc(tree_crc);  // It's really a new cache we're creating so save our crc in the client
    // Now execute the cache create request using this identifier and other configs
    BaseRequest::CreateCacheFlag createFlag = BaseRequest::CreateCacheFlag::kNone;
    CreateCacheRequest::CreateCacheFlag createFlag = CreateCacheRequest::CreateCacheFlag::kNone;
    if (spill_) {
      createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk;
      createFlag |= CreateCacheRequest::CreateCacheFlag::kSpillToDisk;
    }
    if (generate_id) {
      createFlag |= BaseRequest::CreateCacheFlag::kGenerateRowId;
      createFlag |= CreateCacheRequest::CreateCacheFlag::kGenerateRowId;
    }
    CreationCacheRequest rq(connection_identification, cache_mem_sz_, createFlag);
    RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
    Status rc = rq.Wait();
    // Start the comm layer to receive reply
    RETURN_IF_NOT_OK(comm_->ServiceStart());
    // Initiate connection
    auto rq = std::make_shared<CreateCacheRequest>(cinfo_, cache_mem_sz_, createFlag);
    RETURN_IF_NOT_OK(PushRequest(rq));
    Status rc = rq->Wait();
    if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) {
      server_connection_id_ = rq.GetServerConnectionId();
      std::string cookie;
      rq->ParseResult(&server_connection_id_, &cookie);
      if (rc.IsOk()) {
        // The 1st guy creating the cache will get a cookie back.
        // But this object may be shared among pipelines and we don't want
        // to overwrite it.
        cookie_ = rq.cookie();
        cookie_ = cookie;
      }
      // Attach to shared memory for local client
      RETURN_IF_NOT_OK(comm_->AttachToSharedMemory(port_, &local_bypass_));
    }
    // We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the
    // CacheOp to bypass the build phase.
@@ -152,57 +172,57 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
Status CacheClient::PurgeCache() {
  UniqueLock lck(&mux_);
  PurgeCacheRequest rq(server_connection_id_);
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
  return rq.Wait();
  auto rq = std::make_shared<PurgeCacheRequest>(server_connection_id_);
  RETURN_IF_NOT_OK(PushRequest(rq));
  RETURN_IF_NOT_OK(rq->Wait());
  return Status::OK();
}
Status CacheClient::DestroyCache() {
  UniqueLock lck(&mux_);
  DestroyCacheRequest rq(server_connection_id_);
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
  return rq.Wait();
  auto rq = std::make_shared<DestroyCacheRequest>(server_connection_id_);
  RETURN_IF_NOT_OK(PushRequest(rq));
  RETURN_IF_NOT_OK(rq->Wait());
  return Status::OK();
}
Status CacheClient::GetStat(ServiceStat *stat) {
Status CacheClient::GetStat(CacheServiceStat *stat) {
  SharedLock lck(&mux_);
  RETURN_UNEXPECTED_IF_NULL(stat);
  GetStatRequest rq(server_connection_id_);
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
  RETURN_IF_NOT_OK(rq.Wait());
  stat->num_disk_cached = rq.GetNumDiskCached();
  stat->num_mem_cached = rq.GetNumMemCached();
  stat->min_row_id = rq.GetMinRowId();
  stat->max_row_id = rq.GetMaxRowId();
  stat->cache_service_state = rq.GetState();
  auto rq = std::make_shared<GetStatRequest>(server_connection_id_);
  RETURN_IF_NOT_OK(PushRequest(rq));
  RETURN_IF_NOT_OK(rq->Wait());
  rq->GetStat(stat);
  return Status::OK();
}
Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
  SharedLock lck(&mux_);
  CacheSchemaRequest rq(server_connection_id_);
  RETURN_IF_NOT_OK(rq.SerializeCacheSchemaRequest(map));
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
  RETURN_IF_NOT_OK(rq.Wait());
  auto rq = std::make_shared<CacheSchemaRequest>(server_connection_id_);
  RETURN_IF_NOT_OK(rq->SerializeCacheSchemaRequest(map));
  RETURN_IF_NOT_OK(PushRequest(rq));
  RETURN_IF_NOT_OK(rq->Wait());
  return Status::OK();
}
Status CacheClient::FetchSchema(std::unordered_map<std::string, int32_t> *map) {
  SharedLock lck(&mux_);
  RETURN_UNEXPECTED_IF_NULL(map);
  FetchSchemaRequest rq(server_connection_id_);
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
  RETURN_IF_NOT_OK(rq.Wait());
  *map = rq.GetColumnMap();
  auto rq = std::make_shared<FetchSchemaRequest>(server_connection_id_);
  RETURN_IF_NOT_OK(PushRequest(rq));
  RETURN_IF_NOT_OK(rq->Wait());
  *map = rq->GetColumnMap();
  return Status::OK();
}
Status CacheClient::BuildPhaseDone() const {
  SharedLock lck(&mux_);
  BuildPhaseDoneRequest rq(server_connection_id_, cookie());
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
  RETURN_IF_NOT_OK(rq.Wait());
  auto rq = std::make_shared<BuildPhaseDoneRequest>(server_connection_id_, cookie());
  RETURN_IF_NOT_OK(PushRequest(rq));
  RETURN_IF_NOT_OK(rq->Wait());
  return Status::OK();
}
Status CacheClient::PushRequest(std::shared_ptr<BaseRequest> rq) const { return comm_->HandleRequest(std::move(rq)); }
}  // namespace dataset
}  // namespace mindspore
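Nearly every method above repeats the same three steps. Condensed into one illustrative helper (not part of the patch; assumed to live in the mindspore::dataset namespace), the pattern is:

```cpp
// Sketch only: build a request, push it through the comm layer, block on completion.
template <typename RequestT, typename... Args>
Status SendAndWait(const CacheClient &cc, Args &&... args) {
  auto rq = std::make_shared<RequestT>(std::forward<Args>(args)...);
  RETURN_IF_NOT_OK(cc.PushRequest(rq));
  return rq->Wait();
}
```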
@@ -23,9 +23,13 @@
#include <utility>
#include <vector>
#include "minddata/dataset/core/config_manager.h"
#ifdef ENABLE_CACHE
#include "minddata/dataset/engine/cache/cache_grpc_client.h"
#else
#include "minddata/dataset/engine/cache/stub/cache_grpc_client.h"
#endif
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
#include "minddata/dataset/util/lock.h"
namespace mindspore {
@@ -35,18 +39,120 @@ namespace dataset {
/// rows, etc.
class CacheClient {
 public:
  friend class CacheMergeOp;
  /// \brief A builder to help creating a CacheClient object
  class Builder {
   public:
    Builder() : session_id_(0), cache_mem_sz_(0), spill_(false), port_(0), num_workers_(0), prefetch_size_(0) {
      std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
      hostname_ = "127.0.0.1";
      port_ = 50052;
      num_workers_ = cfg->num_parallel_workers();
      prefetch_size_ = 20;  // rows_per_buf is too small (1 by default).
    }
    /// Setter function to set the session id
    /// \param session_id
    /// \return Builder object itself.
    Builder &SetSessionId(session_id_type session_id) {
      session_id_ = session_id;
      return *this;
    }
    /// Setter function to set the cache memory size
    /// \param cache_mem_sz
    /// \return Builder object itself
    Builder &SetCacheMemSz(uint64_t cache_mem_sz) {
      cache_mem_sz_ = cache_mem_sz;
      return *this;
    }
    /// Setter function to set the spill attribute
    /// \param spill
    /// \return Builder object itself
    Builder &SetSpill(bool spill) {
      spill_ = spill;
      return *this;
    }
    /// Setter function to set rpc hostname
    /// \param host
    /// \return Builder object itself
    Builder &SetHostname(std::string host) {
      hostname_ = std::move(host);
      return *this;
    }
    /// Setter function to set tcpip port
    /// \param port
    /// \return Builder object itself.
    Builder &SetPort(int32_t port) {
      port_ = port;
      return *this;
    }
    /// Setter function to set number of async rpc workers
    /// \param num_workers
    /// \return Builder object itself
    Builder &SetNumWorkers(int32_t num_workers) {
      num_workers_ = num_workers;
      return *this;
    }
    /// Setter function to set prefetch amount for fetching rows from cache server
    /// \param prefetch_sz
    /// \return Builder object itself
    Builder &SetPrefetchSize(int32_t prefetch_sz) {
      prefetch_size_ = prefetch_sz;
      return *this;
    }
    /// Getter functions
    session_id_type getSessionId() const { return session_id_; }
    uint64_t getCacheMemSz() const { return cache_mem_sz_; }
    bool isSpill() const { return spill_; }
    const std::string &getHostname() const { return hostname_; }
    int32_t getPort() const { return port_; }
    int32_t getNumWorkers() const { return num_workers_; }
    int32_t getPrefetchSize() const { return prefetch_size_; }
    Status SanityCheck() {
      CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive");
      CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative (0 implies unlimited)");
      CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive");
      CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
      CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
      return Status::OK();
    }
    Status Build(std::shared_ptr<CacheClient> *out) {
      RETURN_UNEXPECTED_IF_NULL(out);
      RETURN_IF_NOT_OK(SanityCheck());
      *out = std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_,
                                           prefetch_size_);
      return Status::OK();
    }
   private:
    session_id_type session_id_;
    uint64_t cache_mem_sz_;
    bool spill_;
    std::string hostname_;
    int32_t port_;
    int32_t num_workers_;
    int32_t prefetch_size_;
  };
  /// \brief Constructor
  /// \param session_id A user assigned session id for the current pipeline
  /// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
  /// \param spill Spill to disk if out of memory
  CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill);
  CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, int32_t port,
              int32_t num_workers, int32_t prefetch_size);
  /// \brief Destructor
  ~CacheClient() = default;
  /// \brief Getter function for returning the current session id
  /// \return session id
  uint64_t session_id() const { return session_id_; }
  ~CacheClient() { (void)comm_->ServiceStop(); }
  /// \brief Send a TensorRow to the cache server
  /// \param[in] row
@@ -83,14 +189,7 @@ class CacheClient {
  /// \brief Get the statistics from a cache.
  /// \param[in/out] Pointer to a pre-allocated CacheServiceStat object
  /// \return Status object
  struct ServiceStat {
    int64_t num_mem_cached;
    int64_t num_disk_cached;
    row_id_type min_row_id;
    row_id_type max_row_id;
    int8_t cache_service_state;
  };
  Status GetStat(ServiceStat *);
  Status GetStat(CacheServiceStat *);
  /// \brief Cache the schema at the cache server
  /// \param map The unordered map of the schema
@@ -122,18 +221,45 @@ class CacheClient {
  /// \return Cookie
  std::string cookie() const { return cookie_; }
  /// \brief Send a request async to the server
  /// \param rq BaseRequest
  /// \return Status object
  Status PushRequest(std::shared_ptr<BaseRequest> rq) const;
  /// \brief If the remote server supports local bypass using shared memory
  /// \return boolean value
  bool SupportLocalClient() const { return local_bypass_; }
  /// \brief Return the base memory address if we attach to any shared memory.
  auto SharedMemoryBaseAddr() const { return comm_->SharedMemoryBaseAddr(); }
  /// Getter functions
  session_id_type session_id() const { return cinfo_.session_id(); }
  uint64_t getCacheMemSz() const { return cache_mem_sz_; }
  bool isSpill() const { return spill_; }
  const std::string &getHostname() const { return hostname_; }
  int32_t getPort() const { return port_; }
  int32_t getNumWorkers() const { return num_workers_; }
  int32_t getPrefetchSize() const { return prefetch_size_; }
 private:
  mutable RWLock mux_;
  uint64_t cache_mem_sz_;
  bool spill_;
  // The session id and cache crc (stored in cinfo_) work together to uniquely identify this particular
  // cache and allow sharing of the cache.
| uint32_t session_id_; | |||
| uint32_t cache_crc_; | |||
| CacheClientInfo cinfo_; | |||
| // The server_connection_id_ is the actual id we use for operations after the cache is built | |||
| connection_id_type server_connection_id_; | |||
| // Some magic cookie returned from the cache server. | |||
| std::string cookie_; | |||
| // Comm layer | |||
| bool local_bypass_; | |||
| std::string hostname_; | |||
| int32_t port_; | |||
| int32_t num_workers_; | |||
| int32_t prefetch_size_; | |||
| mutable std::shared_ptr<CacheClientGreeter> comm_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,90 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_ | |||
| /// \note This header file contains common header files and some inline functions used by | |||
| /// both the client and server side code. Do not put code here that is not common to both. | |||
| /// There are separate client- and server-specific header files. | |||
| // On platforms like Windows, we may support only tcp/ip clients | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| #define CACHE_LOCAL_CLIENT 1 | |||
| #endif | |||
| #ifdef CACHE_LOCAL_CLIENT | |||
| #include <sys/types.h> | |||
| #include <sys/ipc.h> | |||
| #include <sys/shm.h> | |||
| #else | |||
| typedef int key_t; | |||
| #endif | |||
| #ifdef ENABLE_CACHE | |||
| #include <grpcpp/grpcpp.h> | |||
| #endif | |||
| #include <string> | |||
| #ifdef ENABLE_CACHE | |||
| #include "proto/cache_grpc.grpc.pb.h" | |||
| #endif | |||
| #include "proto/cache_grpc.pb.h" | |||
| #include "minddata/dataset/engine/cache/cache_request.h" | |||
| #include "minddata/dataset/engine/cache/de_tensor_generated.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief CacheRow and BatchFetch requests will switch to the shared memory method (if supported | |||
| /// on the platform) when the number of bytes sent is greater than the following threshold. | |||
| /// For smaller amounts we won't get any benefit from the shared memory method, because it | |||
| /// costs an extra rpc request to set up. | |||
| constexpr static int32_t kLocalByPassThreshold = 64 * 1024; | |||
| /// \brief A flag used by the BatchFetch request (client side) if it can support local bypass | |||
| constexpr static uint32_t kLocalClientSupport = 1; | |||
| /// \brief A flag used by the CacheRow request (client side) and the BatchFetch (server side) reply to indicate | |||
| /// that the data is in shared memory rather than inline in the protobuf. It implies kLocalClientSupport is true. | |||
| constexpr static uint32_t kDataIsInSharedMemory = 2; | |||
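| To make the bypass decision above concrete, here is a minimal, self-contained sketch (not part of this patch) of how a client can combine the two flags. SetFlagBit and TestFlagBit are hypothetical stand-ins for the BitSet/BitTest helpers that this patch uses in cache_request.cc. | |||
| // Illustrative only. Assumes kLocalClientSupport, kDataIsInSharedMemory and | |||
| // kLocalByPassThreshold defined above are in scope. | |||
| inline void SetFlagBit(uint32_t *flags, uint32_t bit) { *flags |= bit; } | |||
| inline bool TestFlagBit(uint32_t flags, uint32_t bit) { return (flags & bit) == bit; } | |||
| inline uint32_t BuildCacheRowFlag(bool local_bypass_supported, int64_t payload_sz) { | |||
|   uint32_t flag = 0; | |||
|   if (local_bypass_supported) { | |||
|     SetFlagBit(&flag, kLocalClientSupport); | |||
|     // Shared memory costs an extra rpc round trip, so only use it for large payloads. | |||
|     if (payload_sz >= kLocalByPassThreshold) { | |||
|       SetFlagBit(&flag, kDataIsInSharedMemory); | |||
|     } | |||
|   } | |||
|   return flag; | |||
| } | |||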
| /// \brief Convert a Status object into a protobuf | |||
| /// \param[in] rc Status object | |||
| /// \param[in/out] reply Pointer to a pre-allocated protobuf object | |||
| inline void Status2CacheReply(const Status &rc, CacheReply *reply) { | |||
| reply->set_rc(static_cast<google::int32>(rc.get_code())); | |||
| reply->set_msg(rc.ToString()); | |||
| } | |||
| /// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number | |||
| /// \param port | |||
| /// \return unix socket url | |||
| inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); } | |||
| /// \brief Generate a shared memory key using the tcp/ip port. | |||
| /// \note It must be called after the cache server generates the unix socket or ftok will fail. | |||
| /// \note Caller must check the return value. -1 means ftok failed. | |||
| /// \param[in] port | |||
| /// \param[out] err. If not null and ftok fails, this will contain the value of errno | |||
| /// \return key | |||
| inline key_t PortToFtok(int port, int *err) { | |||
| key_t shmkey = -1; | |||
| #ifdef CACHE_LOCAL_CLIENT | |||
| const std::string unix_path = PortToUnixSocketPath(port); | |||
| shmkey = ftok(unix_path.data(), 'a'); | |||
| if (err != nullptr && shmkey == (key_t)-1) { | |||
| *err = errno; | |||
| } | |||
| #endif | |||
| return shmkey; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_ | |||
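| As a usage note for PortToFtok above: a hypothetical caller sketch (not from this patch) showing the documented error contract, with the port value as a placeholder. | |||
| #include <iostream> | |||
| bool TryGetShmKey(int port) { | |||
|   int err = 0; | |||
|   key_t key = mindspore::dataset::PortToFtok(port, &err); | |||
|   if (key == (key_t)-1) { | |||
|     // ftok fails, for example, when the unix socket file has not been created yet. | |||
|     std::cerr << "ftok failed with errno " << err << std::endl; | |||
|     return false; | |||
|   } | |||
|   return true; | |||
| } | |||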
| @@ -0,0 +1,151 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/cache/cache_fbb.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// A private function used by SerializeTensorRowHeader to serialize the metadata of one column (tensor) in a row | |||
| /// \note Not to be called by the outside world | |||
| /// \return Status object | |||
| Status SerializeOneTensorMeta(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, | |||
| const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off) { | |||
| RETURN_UNEXPECTED_IF_NULL(out_off); | |||
| const Tensor *ts = ts_ptr.get(); | |||
| auto shape_off = fbb->CreateVector(ts->shape().AsVector()); | |||
| const auto ptr = ts->GetBuffer(); | |||
| if (ptr == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("Tensor buffer is null"); | |||
| } | |||
| auto src = ts->type().value(); | |||
| TensorType dest; | |||
| #define CASE(t) \ | |||
| case DataType::t: \ | |||
| dest = TensorType::TensorType_##t; \ | |||
| break | |||
| // Map the type to fill in the flat buffer. | |||
| switch (src) { | |||
| CASE(DE_BOOL); | |||
| CASE(DE_INT8); | |||
| CASE(DE_UINT8); | |||
| CASE(DE_INT16); | |||
| CASE(DE_UINT16); | |||
| CASE(DE_INT32); | |||
| CASE(DE_UINT32); | |||
| CASE(DE_INT64); | |||
| CASE(DE_UINT64); | |||
| CASE(DE_FLOAT16); | |||
| CASE(DE_FLOAT32); | |||
| CASE(DE_FLOAT64); | |||
| CASE(DE_STRING); | |||
| default: | |||
| MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts; | |||
| RETURN_STATUS_UNEXPECTED("Unknown type"); | |||
| } | |||
| #undef CASE | |||
| TensorMetaMsgBuilder ts_builder(*fbb); | |||
| ts_builder.add_dims(shape_off); | |||
| ts_builder.add_type(dest); | |||
| auto ts_off = ts_builder.Finish(); | |||
| *out_off = ts_off; | |||
| return Status::OK(); | |||
| } | |||
| Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *out_fbb) { | |||
| RETURN_UNEXPECTED_IF_NULL(out_fbb); | |||
| // Declare outside the try block; the allocation below can throw std::bad_alloc. | |||
| std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb; | |||
| try { | |||
| fbb = std::make_shared<flatbuffers::FlatBufferBuilder>(); | |||
| std::vector<flatbuffers::Offset<TensorMetaMsg>> v; | |||
| std::vector<int64_t> tensor_sz; | |||
| v.reserve(row.size()); | |||
| tensor_sz.reserve(row.size()); | |||
| // We will go through each column in the row. | |||
| for (const std::shared_ptr<Tensor> &ts_ptr : row) { | |||
| flatbuffers::Offset<TensorMetaMsg> ts_off; | |||
| RETURN_IF_NOT_OK(SerializeOneTensorMeta(fbb, ts_ptr, &ts_off)); | |||
| v.push_back(ts_off); | |||
| tensor_sz.push_back(ts_ptr->SizeInBytes()); | |||
| } | |||
| auto column_off = fbb->CreateVector(v); | |||
| auto data_sz_off = fbb->CreateVector(tensor_sz); | |||
| TensorRowHeaderMsgBuilder row_builder(*fbb); | |||
| row_builder.add_column(column_off); | |||
| row_builder.add_data_sz(data_sz_off); | |||
| // Pass the row_id even if it may not be known. | |||
| row_builder.add_row_id(row.getId()); | |||
| row_builder.add_size_of_this(-1); // fill in later after we call Finish. | |||
| auto out = row_builder.Finish(); | |||
| fbb->Finish(out); | |||
| // Now go back to fill in size_of_this in the flat buffer. | |||
| auto msg = GetMutableTensorRowHeaderMsg(fbb->GetBufferPointer()); | |||
| auto success = msg->mutate_size_of_this(fbb->GetSize()); | |||
| if (!success) { | |||
| RETURN_STATUS_UNEXPECTED("Unable to set size_of_this"); | |||
| } | |||
| (*out_fbb) = std::move(fbb); | |||
| return Status::OK(); | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||
| } | |||
| } | |||
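| The size_of_this back-patching above is a general flatbuffers pattern worth isolating. A reduced sketch, assuming only the generated TensorRowHeaderMsg accessors from de_tensor.fbs are in scope: write a placeholder so the field is physically present in the buffer, Finish(), then mutate it in place. | |||
| flatbuffers::FlatBufferBuilder fbb; | |||
| TensorRowHeaderMsgBuilder builder(fbb); | |||
| builder.add_size_of_this(-1);  // placeholder; a field must be written before it can be mutated | |||
| fbb.Finish(builder.Finish()); | |||
| // Patch the real size into the already-finished buffer. | |||
| auto *msg = GetMutableTensorRowHeaderMsg(fbb.GetBufferPointer()); | |||
| bool ok = msg->mutate_size_of_this(fbb.GetSize());  // returns false if the field is absent | |||
| (void)ok; | |||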
| Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out) { | |||
| RETURN_UNEXPECTED_IF_NULL(col_ts); | |||
| auto shape_in = col_ts->dims(); | |||
| auto type_in = col_ts->type(); | |||
| std::vector<dsize_t> v; | |||
| v.reserve(shape_in->size()); | |||
| v.assign(shape_in->begin(), shape_in->end()); | |||
| TensorShape shape(v); | |||
| DataType::Type dest = DataType::DE_UNKNOWN; | |||
| #define CASE(t) \ | |||
| case TensorType_##t: \ | |||
| dest = DataType::Type::t; \ | |||
| break | |||
| switch (type_in) { | |||
| CASE(DE_BOOL); | |||
| CASE(DE_INT8); | |||
| CASE(DE_UINT8); | |||
| CASE(DE_INT16); | |||
| CASE(DE_UINT16); | |||
| CASE(DE_INT32); | |||
| CASE(DE_UINT32); | |||
| CASE(DE_INT64); | |||
| CASE(DE_UINT64); | |||
| CASE(DE_FLOAT16); | |||
| CASE(DE_FLOAT32); | |||
| CASE(DE_FLOAT64); | |||
| CASE(DE_STRING); | |||
| } | |||
| #undef CASE | |||
| DataType type(dest); | |||
| std::shared_ptr<Tensor> ts; | |||
| RETURN_IF_NOT_OK( | |||
| Tensor::CreateFromMemory(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize(), &ts)); | |||
| // Next we restore the real data which can be embedded or stored separately. | |||
| if (ts->SizeInBytes() != data.GetSize()) { | |||
| MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n" | |||
| << "Dumping tensor\n" | |||
| << *ts << "\n"; | |||
| RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); | |||
| } | |||
| *out = std::move(ts); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_ | |||
| /// This header contains serialization and deserialization functions for a tensor row using | |||
| /// Google FlatBuffers | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/cache/de_tensor_generated.h" | |||
| #include "minddata/dataset/core/tensor_row.h" | |||
| #include "minddata/dataset/util/slice.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief Function to serialize TensorRow header used by CacheRowRequest | |||
| /// \param row TensorRow | |||
| /// \param fbb [in/out] fbb that contains the serialized data | |||
| /// \return Status object | |||
| Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *fbb); | |||
| /// \brief A function used by BatchFetchRequest to deserialize a flat buffer back to a tensor row. | |||
| /// \param col_ts A serialized version of Tensor meta data | |||
| /// \param data Tensor data wrapped in a slice | |||
| /// \param out Tensor | |||
| /// \return Status object | |||
| Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_ | |||
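| To show how the two declarations above compose, a hedged round-trip sketch: row is an existing TensorRow, payload is an assumed ReadableSlice over the raw tensor bytes that travel separately from the header, and error handling is abbreviated. GetTensorRowHeaderMsg is the read accessor generated from de_tensor.fbs. | |||
| std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb; | |||
| RETURN_IF_NOT_OK(SerializeTensorRowHeader(row, &fbb)); | |||
| // The finished buffer is a TensorRowHeaderMsg describing each column of the row. | |||
| auto *msg = GetTensorRowHeaderMsg(fbb->GetBufferPointer()); | |||
| int64_t ts_offset = 0; | |||
| for (flatbuffers::uoffset_t k = 0; k < msg->column()->size(); ++k) { | |||
|   std::shared_ptr<Tensor> ts; | |||
|   ReadableSlice data(payload, ts_offset, msg->data_sz()->Get(k)); | |||
|   RETURN_IF_NOT_OK(RestoreOneTensor(msg->column()->Get(k), data, &ts)); | |||
|   ts_offset += data.GetSize(); | |||
| } | |||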
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| syntax = "proto3"; | |||
| package mindspore.dataset; | |||
| option cc_enable_arenas = true; | |||
| // The session_id and crc work together to uniquely identify this particular cache and allow | |||
| // sharing of the cache. | |||
| message CacheClientInfo { | |||
| uint32 session_id = 1; | |||
| uint32 crc = 2; | |||
| } | |||
| message CacheRequest { | |||
| // Type of rpc request | |||
| int32 type = 1; | |||
| // Extra optional flag used by individual request if needed | |||
| uint32 flag = 2; | |||
| oneof connect_info { | |||
| // The server_connection_id is the actual id we use for operations after the cache is built | |||
| int64 connection_id = 3; | |||
| // But for some requests, like CreateCache, we have to use the session id and crc to connect to the server. | |||
| CacheClientInfo connection_info = 4; | |||
| } | |||
| // Everything else is just a vector of buffers | |||
| repeated bytes buf_data = 5; | |||
| } | |||
| message CacheReply { | |||
| int32 rc = 1; | |||
| string msg = 2; | |||
| // Extra optional flag used by individual request if needed | |||
| uint32 flag = 3; | |||
| // What the server sends back is a plain buffer | |||
| bytes result = 4; | |||
| } | |||
| service CacheServerGreeter { | |||
| rpc CacheServerRequest (CacheRequest) returns (CacheReply) {} | |||
| } | |||
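| For reference, a hedged sketch of the generated C++ API for the messages above; all values are placeholders. Note that setting one member of the connect_info oneof implicitly clears the other. | |||
| mindspore::dataset::CacheRequest rq; | |||
| rq.set_type(1);  // a numeric RequestType value; the enum lives in cache_request.h | |||
| // Before a cache exists, identify ourselves by session id and crc. | |||
| rq.mutable_connection_info()->set_session_id(1); | |||
| rq.mutable_connection_info()->set_crc(0x1234); | |||
| // After CreateCache returns, switch to the server-assigned connection id. | |||
| rq.set_connection_id(42);  // this clears connection_info | |||
| rq.add_buf_data("payload bytes"); | |||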
| @@ -0,0 +1,161 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/cache/cache_grpc_client.h" | |||
| #include <chrono> | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status CacheClientRequestTag::MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq, | |||
| std::unique_ptr<CacheClientRequestTag> &&tag) { | |||
| // Do any extra preparation the request needs before we send. | |||
| RETURN_IF_NOT_OK(tag->base_rq_->Prepare()); | |||
| // One minute timeout | |||
| auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60); | |||
| tag->ctx_.set_deadline(deadline); | |||
| tag->rpc_ = stub->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, cq); | |||
| tag->rpc_->StartCall(); | |||
| // The last step is to release ownership and transfer it to the completion queue. | |||
| // The memory will be released by WorkerEntry or by the destructor when we drain the queue. | |||
| auto ccReqTag = tag.release(); | |||
| ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_, | |||
| ccReqTag); // inject this object into the completion queue | |||
| return Status::OK(); | |||
| } | |||
| CacheClientGreeter::~CacheClientGreeter() { | |||
| (void)ServiceStop(); | |||
| // Detach from shared memory if any | |||
| if (shmat_addr_ != nullptr) { | |||
| shmdt(shmat_addr_); | |||
| shmat_addr_ = nullptr; | |||
| } | |||
| } | |||
| CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers) | |||
| : num_workers_(num_workers), shm_key_(-1), shm_id_(-1), shmat_addr_(nullptr) { | |||
| grpc::ChannelArguments args; | |||
| // We need to bump up the message size to unlimited. The default receiving | |||
| // message limit is 4MB, which is not big enough. | |||
| args.SetMaxReceiveMessageSize(-1); | |||
| #if CACHE_LOCAL_CLIENT | |||
| // Prefer connecting locally over the unix socket. | |||
| // Ideally we would resolve the hostname to an ip address rather than doing a string compare. | |||
| if (hostname == "127.0.0.1") { | |||
| std::string target = "unix://" + PortToUnixSocketPath(port); | |||
| channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args); | |||
| } else { | |||
| #endif | |||
| std::string target = hostname + ":" + std::to_string(port); | |||
| channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args); | |||
| #if CACHE_LOCAL_CLIENT | |||
| } | |||
| #endif | |||
| stub_ = CacheServerGreeter::NewStub(channel_); | |||
| } | |||
| Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass) { | |||
| *local_bypass = false; | |||
| #if CACHE_LOCAL_CLIENT | |||
| int err; | |||
| shm_key_ = PortToFtok(port, &err); | |||
| if (shm_key_ == (key_t)-1) { | |||
| std::string errMsg = "Ftok failed with errno " + std::to_string(err); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| // Attach to the shared memory | |||
| shm_id_ = shmget(shm_key_, 0, 0); | |||
| if (shm_id_ == -1) { | |||
| RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno)); | |||
| } | |||
| shmat_addr_ = shmat(shm_id_, nullptr, 0); | |||
| if (shmat_addr_ == reinterpret_cast<void *>(-1)) { | |||
| RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno)); | |||
| } | |||
| *local_bypass = true; | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
| Status CacheClientGreeter::DoServiceStart() { | |||
| RETURN_IF_NOT_OK(vg_.ServiceStart()); | |||
| RETURN_IF_NOT_OK(DispatchWorkers(num_workers_)); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheClientGreeter::DoServiceStop() { | |||
| // Shut down the queue. No new incoming requests are accepted. | |||
| cq_.Shutdown(); | |||
| // Shutdown the TaskGroup. | |||
| vg_.interrupt_all(); | |||
| vg_.join_all(Task::WaitFlag::kNonBlocking); | |||
| // Drain the queue | |||
| bool success; | |||
| void *tag; | |||
| while (cq_.Next(&tag, &success)) { | |||
| auto r = reinterpret_cast<CacheClientRequestTag *>(tag); | |||
| delete r; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) { | |||
| auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq)); | |||
| return tag->MakeCall(stub_.get(), &cq_, std::move(tag)); | |||
| } | |||
| Status CacheClientGreeter::WorkerEntry() { | |||
| TaskManager::FindMe()->Post(); | |||
| do { | |||
| bool success; | |||
| void *tag; | |||
| auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1); | |||
| // Set a timeout for one second. Check for interrupt if we need to do early exit. | |||
| auto r = cq_.AsyncNext(&tag, &success, deadline); | |||
| if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) { | |||
| auto rq = reinterpret_cast<CacheClientRequestTag *>(tag); | |||
| if (success) { | |||
| auto &rc = rq->rc_; | |||
| if (!rc.ok()) { | |||
| auto error_code = rq->rc_.error_code(); | |||
| std::string errMsg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code); | |||
| Status remote_rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| Status2CacheReply(remote_rc, &rq->base_rq_->reply_); | |||
| } | |||
| // Notify the waiting thread. | |||
| rq->Notify(); | |||
| } | |||
| // We can now free the memory | |||
| delete rq; | |||
| } else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) { | |||
| // If we are interrupted, exit. Otherwise wait again. | |||
| RETURN_IF_INTERRUPTED(); | |||
| } else { | |||
| // Queue is drained. | |||
| break; | |||
| } | |||
| } while (true); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheClientGreeter::DispatchWorkers(int32_t num_workers) { | |||
| auto f = std::bind(&CacheClientGreeter::WorkerEntry, this); | |||
| for (auto i = 0; i < num_workers; ++i) { | |||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Async reply", f)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,102 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "minddata/dataset/engine/cache/cache_common.h" | |||
| #include "minddata/dataset/util/service.h" | |||
| #include "minddata/dataset/util/task_manager.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief A client view of gRPC request | |||
| /// Like the class CacheServerRequest, this is used as a tag to inject into the gRPC | |||
| /// completion queue. The thread that makes the rpc request will wait on a wait post | |||
| /// area for the reply to come back. Since this tag will be deleted from memory, we | |||
| /// need to hold a shared pointer to the BaseRequest such that its use count is at | |||
| /// least two. Otherwise one of the threads would be referencing stale memory. | |||
| /// \see CacheServerRequest | |||
| class CacheClientRequestTag { | |||
| public: | |||
| friend class CacheClientGreeter; | |||
| explicit CacheClientRequestTag(std::shared_ptr<BaseRequest> rq) : base_rq_(std::move(rq)) {} | |||
| ~CacheClientRequestTag() = default; | |||
| /// \brief Make a RPC call | |||
| /// \param stub from CacheClientGreeter | |||
| /// \param cq from CacheClientGreeter | |||
| /// \return Status object | |||
| static Status MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq, | |||
| std::unique_ptr<CacheClientRequestTag> &&tag); | |||
| /// \brief Notify the client that a result has come back from the server | |||
| void Notify() { base_rq_->wp_.Set(); } | |||
| private: | |||
| std::shared_ptr<BaseRequest> base_rq_; | |||
| grpc::Status rc_; | |||
| grpc::ClientContext ctx_; | |||
| std::unique_ptr<grpc::ClientAsyncResponseReader<CacheReply>> rpc_; | |||
| }; | |||
| /// \brief A GRPC layer to convert BaseRequest into protobuf and send to the cache server using gRPC | |||
| /// \see BaseRequest | |||
| class CacheClientGreeter : public Service { | |||
| friend class CacheClient; | |||
| public: | |||
| explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers); | |||
| ~CacheClientGreeter(); | |||
| /// Override base Service class | |||
| Status DoServiceStart() override; | |||
| Status DoServiceStop() override; | |||
| /// \brief Send the request to the server | |||
| /// \return Status object | |||
| Status HandleRequest(std::shared_ptr<BaseRequest> rq); | |||
| /// \brief A handful of threads handle async replies from the server | |||
| /// \return Status object | |||
| Status WorkerEntry(); | |||
| /// \brief Kick off threads to receive reply from the server | |||
| Status DispatchWorkers(int32_t num_workers); | |||
| /// \brief Attach to shared memory for local client | |||
| /// \note Called after we have established a connection. | |||
| /// \return Status object. | |||
| Status AttachToSharedMemory(int32_t port, bool *local_bypass); | |||
| /// \brief This returns where we attach to the shared memory. | |||
| /// \return Base address of the shared memory. | |||
| const void *SharedMemoryBaseAddr() const { return shmat_addr_; } | |||
| private: | |||
| std::shared_ptr<grpc::Channel> channel_; | |||
| std::unique_ptr<CacheServerGreeter::Stub> stub_; | |||
| grpc::CompletionQueue cq_; | |||
| TaskGroup vg_; | |||
| int32_t num_workers_; | |||
| key_t shm_key_; | |||
| int32_t shm_id_; | |||
| void *shmat_addr_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_ | |||
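| A hedged sketch of the client-side call flow through this class. It assumes the Service base class exposes ServiceStart()/ServiceStop() wrappers over the Do* overrides (the CacheClient code above already calls ServiceStop()); the port and MakeSomeRequest() are placeholders. | |||
| auto comm = std::make_shared<CacheClientGreeter>("127.0.0.1", 50052, /*num_workers=*/2); | |||
| RETURN_IF_NOT_OK(comm->ServiceStart());  // spins up the async reply workers | |||
| bool local_bypass = false; | |||
| RETURN_IF_NOT_OK(comm->AttachToSharedMemory(50052, &local_bypass)); | |||
| std::shared_ptr<BaseRequest> rq = MakeSomeRequest();  // hypothetical factory for a concrete request | |||
| RETURN_IF_NOT_OK(comm->HandleRequest(rq));  // async send; tag ownership moves to the queue | |||
| RETURN_IF_NOT_OK(rq->Wait());  // block on the wait post until a reply worker sets it | |||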
| @@ -0,0 +1,203 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/cache/cache_grpc_server.h" | |||
| #include <limits> | |||
| #include "minddata/dataset/engine/cache/cache_server.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb) | |||
| : port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb) { | |||
| // Setup a path for unix socket. | |||
| unix_socket_ = PortToUnixSocketPath(port); | |||
| // We can't generate the ftok key yet until the unix_socket_ is created | |||
| } | |||
| void CacheServerGreeterImpl::Shutdown() { | |||
| if (server_) { | |||
| auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1); | |||
| server_->Shutdown(deadline); | |||
| } | |||
| // Always shutdown the completion queue after the server. | |||
| if (cq_) { | |||
| cq_->Shutdown(); | |||
| // We need to drain the queue. All the tags come from the | |||
| // services pool, which will be shut down as well, so we | |||
| // ignore the tags. | |||
| void *tag; | |||
| bool success; | |||
| while (cq_->Next(&tag, &success)) { | |||
| } | |||
| } | |||
| } | |||
| CacheServerGreeterImpl::~CacheServerGreeterImpl() { Shutdown(); } | |||
| Status CacheServerGreeterImpl::IpcResourceCleanup() { | |||
| #if CACHE_LOCAL_CLIENT | |||
| int err; | |||
| auto shm_key = PortToFtok(port_, &err); | |||
| // We expect that the unix path doesn't exist. | |||
| if (shm_key == (key_t)-1) { | |||
| return Status::OK(); | |||
| } | |||
| // Attach to the shared memory | |||
| auto shm_id = shmget(shm_key, 0, 0); | |||
| if (shm_id == -1) { | |||
| return Status::OK(); | |||
| } | |||
| struct shmid_ds ds {}; | |||
| auto inx = shmctl(shm_id, IPC_STAT, &ds); | |||
| if (inx == -1) { | |||
| std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id); | |||
| errMsg += "\nPlesae remove it manually using ipcrm -m command"; | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| if (ds.shm_nattch == 0) { | |||
| // Stale shared memory from last time. | |||
| // Remove both the memory and the socket path | |||
| inx = shmctl(shm_id, IPC_RMID, nullptr); | |||
| if (inx == -1) { | |||
| std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id); | |||
| errMsg += ". Errno :" + std::to_string(errno); | |||
| errMsg += "\nPlesae remove it manually using ipcrm -m command"; | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| Path p(unix_socket_); | |||
| (void)p.Remove(); | |||
| } else { | |||
| // Server is already up. | |||
| MS_LOG(ERROR) << "Cache server is already up and running"; | |||
| // We return a duplicate error. The main() will intercept | |||
| // and output a proper message | |||
| return Status(StatusCode::kDuplicateKey); | |||
| } | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServerGreeterImpl::Run() { | |||
| // To listen on all interfaces, use 0.0.0.0 | |||
| // Use 127.0.0.1 if just locally on the same machine. | |||
| std::string host("0.0.0.0"); // listen on all interfaces. | |||
| std::string server_address = host + ":" + std::to_string(port_); | |||
| grpc::ServerBuilder builder; | |||
| // The default message size for gRPC is 4MB. Increase it to 2GB - 1 (the int32 maximum). | |||
| builder.SetMaxReceiveMessageSize(std::numeric_limits<int32_t>::max()); | |||
| int port_tcpip = 0; | |||
| #if CACHE_LOCAL_CLIENT | |||
| int port_local = 0; | |||
| // Check if we need to clean up the shared memory in case the server | |||
| // previously went down unexpectedly (e.g. a SEGV). | |||
| RETURN_IF_NOT_OK(IpcResourceCleanup()); | |||
| // We also optimize for local clients on the same machine by using a unix socket. | |||
| builder.AddListeningPort("unix://" + unix_socket_, grpc::InsecureServerCredentials(), &port_local); | |||
| #endif | |||
| builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_tcpip); | |||
| builder.RegisterService(&svc_); | |||
| cq_ = builder.AddCompletionQueue(); | |||
| server_ = builder.BuildAndStart(); | |||
| if (server_) { | |||
| MS_LOG(INFO) << "Server listening on " << server_address; | |||
| #if CACHE_LOCAL_CLIENT | |||
| RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_)); | |||
| MS_LOG(INFO) << "Creation of local socket and shared memory successful"; | |||
| #endif | |||
| } else { | |||
| std::string errMsg = "Fail to start server. "; | |||
| if (port_tcpip != port_) { | |||
| errMsg += "Unable to bind to tcpip port " + std::to_string(port_) + "."; | |||
| } | |||
| #if CACHE_LOCAL_CLIENT | |||
| if (port_local == 0) { | |||
| errMsg += " Unable to create unix socket " + unix_socket_ + "."; | |||
| } | |||
| #endif | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServerGreeterImpl::HandleRequest(int32_t worker_id) { | |||
| bool success; | |||
| void *tag; | |||
| // We loop on the grpc completion queue. Each successful event comes | |||
| // back with our own tag, which is an instance of CacheServerRequest, | |||
| // and we simply call its functor. But first we need to create these instances | |||
| // and inject them into the grpc queue. | |||
| CacheServerRequest *p; | |||
| // Get a free tag from my free list. | |||
| RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(worker_id, &p)); | |||
| RETURN_IF_NOT_OK((*p)(&svc_, cq_.get())); | |||
| do { | |||
| auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1); | |||
| // Set a timeout for one second. Check for interrupt if we need to do early exit. | |||
| auto r = cq_->AsyncNext(&tag, &success, deadline); | |||
| if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) { | |||
| if (success) { | |||
| auto rq = static_cast<CacheServerRequest *>(tag); | |||
| RETURN_IF_NOT_OK((*rq)(&svc_, cq_.get())); | |||
| } | |||
| } else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) { | |||
| // If we are interrupted, exit. Otherwise wait again. | |||
| RETURN_IF_INTERRUPTED(); | |||
| } else { | |||
| // Queue is drained. | |||
| break; | |||
| } | |||
| } while (true); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grpc::ServerCompletionQueue *cq) { | |||
| auto myQID = getQid(); | |||
| if (st_ == STATE::CREATE) { | |||
| st_ = STATE::PROCESS; | |||
| svc->RequestCacheServerRequest(&ctx_, &rq_, &responder_, cq, cq, this); | |||
| } else if (st_ == STATE::PROCESS) { | |||
| // Get a new tag and handle the next request before we serve the current request. | |||
| // The tag will be recycled when its state is changed to FINISH | |||
| CacheServerRequest *next_rq; | |||
| RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(myQID, &next_rq)); | |||
| RETURN_IF_NOT_OK((*next_rq)(svc, cq)); | |||
| // Now we continue with the current request. | |||
| // First thing we need to extract the type from the incoming request. | |||
| // When this object was first created (i.e. STATE::CREATE), we set the type to UNKNOWN. | |||
| type_ = static_cast<RequestType>(rq_.type()); | |||
| // Now we pass the address of this instance to CacheServer's main loop. | |||
| MS_LOG(DEBUG) << "Handle request " << *this; | |||
| auto &cs = CacheServer::GetInstance(); | |||
| RETURN_IF_NOT_OK(cs.PushRequest(myQID, this)); | |||
| } else if (st_ == STATE::FINISH) { | |||
| MS_LOG(DEBUG) << *this << " Finished."; | |||
| // Return back to the free list. | |||
| RETURN_IF_NOT_OK(CacheServer::ReturnRequestTag(this)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| void CacheServerRequest::Print(std::ostream &out) const { | |||
| if (rq_.has_connection_info()) { | |||
| out << "Session Id: " << rq_.connection_info().session_id() << " CRC: " << rq_.connection_info().crc(); | |||
| } else { | |||
| out << "Connection Id: " << rq_.connection_id(); | |||
| } | |||
| out << " "; | |||
| BaseRequest::Print(out); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,103 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/cache/cache_common.h" | |||
| #include "minddata/dataset/engine/cache/cache_arena.h" | |||
| #include "minddata/dataset/util/allocator.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/dataset/util/task_manager.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief Server side view of BaseRequest. Incoming requests are in the form of protobuf objects | |||
| /// and this class is used to translate from protobuf to structures understood by CacheService class. | |||
| /// \see CacheService | |||
| class CacheServerRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 }; | |||
| explicit CacheServerRequest(int32_t queue_id) | |||
| : BaseRequest::BaseRequest(BaseRequest::RequestType::kRequestUnknown), | |||
| qid_(queue_id), | |||
| st_(STATE::CREATE), | |||
| responder_(&ctx_) {} | |||
| ~CacheServerRequest() = default; | |||
| /// \brief Functor. Used mainly by the CacheServerGreeterImpl class to tag each incoming request; it | |||
| /// translates each protobuf into a form understood by the CacheService class. | |||
| /// \param svc Async service | |||
| /// \param cq Completion queue | |||
| /// \return Status object | |||
| Status operator()(CacheServerGreeter::AsyncService *svc, grpc::ServerCompletionQueue *cq); | |||
| /// \brief Override the base class Print method | |||
| /// \param out | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Getter of the queue id | |||
| /// \return The queue where the request should go to | |||
| int32_t getQid() const { return qid_; } | |||
| private: | |||
| int32_t qid_; | |||
| Status rc_; | |||
| STATE st_; | |||
| grpc::ServerContext ctx_; | |||
| grpc::ServerAsyncResponseWriter<CacheReply> responder_; | |||
| }; | |||
| /// \brief Implementation of CacheServerGreeter | |||
| /// \note It is an async server | |||
| /// \see cache_grpc.proto | |||
| class CacheServerGreeterImpl final { | |||
| friend class CacheServer; | |||
| public: | |||
| explicit CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb); | |||
| virtual ~CacheServerGreeterImpl(); | |||
| /// \brief Brings up the gRPC server | |||
| /// \return Status object | |||
| Status Run(); | |||
| /// \brief Entry function to handle cache server request | |||
| Status HandleRequest(int32_t worker_id); | |||
| /// \brief Getter for the shared memory pool. | |||
| /// \return Pointer to the shared memory pool | |||
| CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); } | |||
| void Shutdown(); | |||
| Status IpcResourceCleanup(); | |||
| private: | |||
| int32_t port_; | |||
| size_t shm_pool_sz_in_gb_; | |||
| std::string unix_socket_; | |||
| CacheServerGreeter::AsyncService svc_; | |||
| std::unique_ptr<grpc::ServerCompletionQueue> cq_; | |||
| std::unique_ptr<grpc::Server> server_; | |||
| std::unique_ptr<CachedSharedMemoryArena> shm_pool_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_ | |||
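| A hedged sketch of how a driver such as CacheServer might bring this server up; task_group and num_workers are placeholders, and CreateAsyncTask mirrors the TaskGroup usage in the client code above. | |||
| CacheServerGreeterImpl greeter(/*port=*/50052, /*shared_memory_sz_in_gb=*/4); | |||
| RETURN_IF_NOT_OK(greeter.Run());  // bind the tcp (and unix socket) ports and create the shm arena | |||
| for (int32_t i = 0; i < num_workers; ++i) { | |||
|   // Each worker polls the shared completion queue until shutdown. | |||
|   RETURN_IF_NOT_OK(task_group.CreateAsyncTask("rpc worker", [&greeter, i]() { return greeter.HandleRequest(i); })); | |||
| } | |||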
| @@ -0,0 +1,121 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/cache/cache_server.h" | |||
| #include <sys/stat.h>  // umask() | |||
| #include <sys/types.h> | |||
| #include <unistd.h> | |||
| #ifdef USE_GLOG | |||
| #include <glog/logging.h> | |||
| #endif | |||
| #include <csignal>  // signal() | |||
| #include <cstdlib> | |||
| #include <cstring>  // strcmp() | |||
| namespace ds = mindspore::dataset; | |||
| int main(int argc, char **argv) { | |||
| ds::Status rc; | |||
| ds::CacheServer::Builder builder; | |||
| // This executable is not to be called directly; it should be invoked by the cache_admin executable. | |||
| if (argc != 7) { | |||
| rc = ds::Status(ds::StatusCode::kSyntaxError); | |||
| std::cerr << rc.ToString() << std::endl; | |||
| return static_cast<int>(rc.get_code()); | |||
| } | |||
| builder.SetRootDirectory(argv[1]) | |||
| .SetNumWorkers(strtol(argv[2], nullptr, 10)) | |||
| .SetPort(strtol(argv[3], nullptr, 10)) | |||
| .SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10)); | |||
| #ifdef USE_GLOG | |||
| FLAGS_minloglevel = strtol(argv[5], nullptr, 10); | |||
| #endif | |||
| auto daemonize_string = argv[6]; | |||
| bool daemonize = strcmp(daemonize_string, "true") == 0 || strcmp(daemonize_string, "TRUE") == 0 || | |||
| strcmp(daemonize_string, "t") == 0 || strcmp(daemonize_string, "T") == 0; | |||
| // We always change directory to / on unix rather than staying in the directory where cache_server | |||
| // was invoked. This is the standard procedure for daemonizing a process on unix. | |||
| if (chdir("/") == -1) { | |||
| std::string errMsg = "Unable to change directory to /. Errno = " + std::to_string(errno); | |||
| std::cerr << errMsg << std::endl; | |||
| return -1; | |||
| } | |||
| // Simple check of the parameters before we move on. | |||
| rc = builder.SanityCheck(); | |||
| if (rc.IsError()) { | |||
| std::cerr << rc.ToString() << std::endl; | |||
| return static_cast<int>(rc.get_code()); | |||
| } | |||
| #ifdef USE_GLOG | |||
| FLAGS_log_dir = "/tmp"; | |||
| google::InitGoogleLogging(argv[0]); | |||
| #endif | |||
| if (daemonize) { | |||
| // fork the child process to become the daemon | |||
| pid_t pid = fork(); | |||
| // failed to fork | |||
| if (pid < 0) { | |||
| std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno); | |||
| std::cerr << err_msg << std::endl; | |||
| return errno; | |||
| } else if (pid > 0) { | |||
| // Parent | |||
| std::cerr << "cache server daemon process has been created as process id: " << pid | |||
| << "\nCheck log file for any start up error" << std::endl; | |||
| signal(SIGCHLD, SIG_IGN); // ignore sig child signal. | |||
| return 0; | |||
| } else { | |||
| // The child process continues from here when daemonizing; the parent has already exited. | |||
| // If we are running in the foreground, none of the code in the block below is run. | |||
| pid_t sid; | |||
| umask(0); | |||
| sid = setsid(); | |||
| if (sid < 0) { | |||
| MS_LOG(ERROR) << "Failed to setsid(). Errno = " << std::to_string(errno); | |||
| return errno; | |||
| } | |||
| close(0); | |||
| close(1); | |||
| close(2); | |||
| } | |||
| } | |||
| // Dump the summary | |||
| MS_LOG(INFO) << builder << std::endl; | |||
| rc = builder.Build(); | |||
| if (rc.IsOk()) { | |||
| ds::CacheServer &cs = ds::CacheServer::GetInstance(); | |||
| // Kick off the threads. This loops forever and never returns unless there is an error. | |||
| rc = cs.Run(); | |||
| if (rc.get_code() == ds::StatusCode::kDuplicateKey) { | |||
| std::string errMsg = "Server is already started"; | |||
| MS_LOG(ERROR) << errMsg; | |||
| std::cerr << errMsg << std::endl; | |||
| return 0; | |||
| } | |||
| } | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << rc.ToString(); | |||
| std::cerr << rc.ToString() << std::endl; | |||
| return static_cast<int>(rc.get_code()); | |||
| } | |||
| return 0; | |||
| } | |||
| @@ -14,154 +14,149 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/cache/cache_request.h" | |||
| #include <cstdlib> | |||
| #include <thread> | |||
| #include "minddata/dataset/core/constants.h" | |||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||
| #include "minddata/dataset/engine/cache/cache_fbb.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status CacheRowRequest::SerializeCacheRowRequest(const TensorRow &row) { | |||
| buffers_.reserve(row.size() + 1); | |||
| RETURN_IF_NOT_OK(SerializeTensorRowHeader(row)); | |||
| buffers_.push_back(fbb_->GetBufferPointer()); | |||
| for (const auto &ts : row) { | |||
| buffers_.push_back(ts->GetBuffer()); | |||
| } | |||
| Status BaseRequest::Wait() { | |||
| RETURN_IF_NOT_OK(wp_.Wait()); | |||
| Status remote_rc(static_cast<StatusCode>(reply_.rc()), reply_.msg()); | |||
| RETURN_IF_NOT_OK(remote_rc); | |||
| // Any extra work to do before we return to the client. | |||
| RETURN_IF_NOT_OK(PostReply()); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheRowRequest::SerializeTensorRowHeader(const TensorRow &row) { | |||
| try { | |||
| fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>(); | |||
| std::vector<flatbuffers::Offset<TensorMetaMsg>> v; | |||
| std::vector<int64_t> tensor_sz; | |||
| v.reserve(row.size()); | |||
| tensor_sz.reserve(row.size()); | |||
| // We will go through each column in the row. | |||
| for (const std::shared_ptr<Tensor> &ts_ptr : row) { | |||
| flatbuffers::Offset<TensorMetaMsg> ts_off; | |||
| RETURN_IF_NOT_OK(SerializeOneTensorMeta(ts_ptr, &ts_off)); | |||
| v.push_back(ts_off); | |||
| tensor_sz.push_back(ts_ptr->SizeInBytes()); | |||
| Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(row.size() > 0, "Empty tensor row"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(cc->SupportLocalClient() == support_local_bypass_, "Local bypass mismatch"); | |||
| // Calculate how many bytes (not counting the cookie) we are sending to the server. We only | |||
| // use shared memory (if supported) if we exceed a certain amount. | |||
| std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb; | |||
| RETURN_IF_NOT_OK(::mindspore::dataset::SerializeTensorRowHeader(row, &fbb)); | |||
| sz_ += fbb->GetSize(); | |||
| for (const auto &ts : row) { | |||
| sz_ += ts->SizeInBytes(); | |||
| } | |||
| bool sent_using_local_bypass = support_local_bypass_ ? (sz_ >= kLocalByPassThreshold) : false; | |||
| uint32_t flag = 0; | |||
| if (support_local_bypass_) { | |||
| BitSet(&flag, kLocalClientSupport); | |||
| } | |||
| if (sent_using_local_bypass) { | |||
| BitSet(&flag, kDataIsInSharedMemory); | |||
| } | |||
| rq_.set_flag(flag); | |||
| if (sent_using_local_bypass) { | |||
| MS_LOG(DEBUG) << "Requesting " << sz_ << " bytes of shared memory data"; | |||
| // Allocate shared memory from the server | |||
| auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), sz_); | |||
| RETURN_IF_NOT_OK(cc->PushRequest(mem_rq)); | |||
| RETURN_IF_NOT_OK(mem_rq->Wait()); | |||
| addr_ = mem_rq->GetAddr(); | |||
| // Now we need to add that to the base address of where we attach. | |||
| auto base = cc->SharedMemoryBaseAddr(); | |||
| auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr_); | |||
| // Now we copy the data onto shared memory. | |||
| WritableSlice all(p, sz_); | |||
| auto offset = fbb->GetSize(); | |||
| ReadableSlice header(fbb->GetBufferPointer(), fbb->GetSize()); | |||
| Status copy_rc; | |||
| copy_rc = WritableSlice::Copy(&all, header); | |||
| if (copy_rc.IsOk()) { | |||
| for (const auto &ts : row) { | |||
| WritableSlice row_data(all, offset, ts->SizeInBytes()); | |||
| ReadableSlice src(ts->GetBuffer(), ts->SizeInBytes()); | |||
| copy_rc = WritableSlice::Copy(&row_data, src); | |||
| if (copy_rc.IsError()) { | |||
| break; | |||
| } | |||
| offset += ts->SizeInBytes(); | |||
| } | |||
| // Fill in where to find the data | |||
| AddDataLocation(); | |||
| } | |||
| auto column_off = fbb_->CreateVector(v); | |||
| auto data_sz_off = fbb_->CreateVector(tensor_sz); | |||
| TensorRowHeaderMsgBuilder row_builder(*fbb_); | |||
| row_builder.add_column(column_off); | |||
| row_builder.add_data_sz(data_sz_off); | |||
| // Pass the row_id even if it may not be known. | |||
| row_builder.add_row_id(row.getId()); | |||
| row_builder.add_size_of_this(-1); // fill in later after we call Finish. | |||
| auto out = row_builder.Finish(); | |||
| fbb_->Finish(out); | |||
| // Now go back to fill in size_of_this in the flat buffer. | |||
| auto msg = GetMutableTensorRowHeaderMsg(fbb_->GetBufferPointer()); | |||
| auto success = msg->mutate_size_of_this(fbb_->GetSize()); | |||
| if (!success) { | |||
| RETURN_STATUS_UNEXPECTED("Unable to set size_of_this"); | |||
| if (copy_rc.IsError()) { | |||
| // We need to return the memory back to the server | |||
| auto mfree_req = GenerateFreeBlockRequest(); | |||
| Status rc = cc->PushRequest(mfree_req); | |||
| // But we won't wait for the result for the sake of performance. | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "Push request for free memory failed."; | |||
| } | |||
| return copy_rc; | |||
| } | |||
| return Status::OK(); | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||
| } else { | |||
| // We have already filled the first buffer, which is the cookie. | |||
| sz_ += rq_.buf_data(0).size(); | |||
| rq_.add_buf_data(fbb->GetBufferPointer(), fbb->GetSize()); | |||
| for (const auto &ts : row) { | |||
| rq_.add_buf_data(ts->GetBuffer(), ts->SizeInBytes()); | |||
| } | |||
| MS_LOG(DEBUG) << "Sending " << sz_ << " bytes of tensor data in " << rq_.buf_data_size() << " segments"; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheRowRequest::SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr, | |||
| flatbuffers::Offset<TensorMetaMsg> *out_off) { | |||
| RETURN_UNEXPECTED_IF_NULL(out_off); | |||
| const Tensor *ts = ts_ptr.get(); | |||
| auto shape_off = fbb_->CreateVector(ts->shape().AsVector()); | |||
| const auto ptr = ts->GetBuffer(); | |||
| if (ptr == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("Tensor buffer is null"); | |||
| } | |||
| auto src = ts->type().value(); | |||
| TensorType dest; | |||
| #define CASE(t) \ | |||
| case DataType::t: \ | |||
| dest = TensorType::TensorType_##t; \ | |||
| break | |||
| // Map the type to fill in the flat buffer. | |||
| switch (src) { | |||
| CASE(DE_BOOL); | |||
| CASE(DE_INT8); | |||
| CASE(DE_UINT8); | |||
| CASE(DE_INT16); | |||
| CASE(DE_UINT16); | |||
| CASE(DE_INT32); | |||
| CASE(DE_UINT32); | |||
| CASE(DE_INT64); | |||
| CASE(DE_UINT64); | |||
| CASE(DE_FLOAT16); | |||
| CASE(DE_FLOAT32); | |||
| CASE(DE_FLOAT64); | |||
| CASE(DE_STRING); | |||
| default: | |||
| MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts; | |||
| RETURN_STATUS_UNEXPECTED("Unknown type"); | |||
| Status CacheRowRequest::PostReply() { | |||
| if (!reply_.result().empty()) { | |||
| row_id_from_server_ = strtoll(reply_.result().data(), nullptr, 10); | |||
| } | |||
| #undef CASE | |||
| TensorMetaMsgBuilder ts_builder(*fbb_); | |||
| ts_builder.add_dims(shape_off); | |||
| ts_builder.add_type(dest); | |||
| auto ts_off = ts_builder.Finish(); | |||
| *out_off = ts_off; | |||
| return Status::OK(); | |||
| } | |||
| Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, | |||
| std::shared_ptr<Tensor> *out) { | |||
| RETURN_UNEXPECTED_IF_NULL(col_ts); | |||
| auto shape_in = col_ts->dims(); | |||
| auto type_in = col_ts->type(); | |||
| std::vector<dsize_t> v; | |||
| v.reserve(shape_in->size()); | |||
| v.assign(shape_in->begin(), shape_in->end()); | |||
| TensorShape shape(v); | |||
| DataType::Type dest = DataType::DE_UNKNOWN; | |||
| #define CASE(t) \ | |||
| case TensorType_##t: \ | |||
| dest = DataType::Type::t; \ | |||
| break | |||
| switch (type_in) { | |||
| CASE(DE_BOOL); | |||
| CASE(DE_INT8); | |||
| CASE(DE_UINT8); | |||
| CASE(DE_INT16); | |||
| CASE(DE_UINT16); | |||
| CASE(DE_INT32); | |||
| CASE(DE_UINT32); | |||
| CASE(DE_INT64); | |||
| CASE(DE_UINT64); | |||
| CASE(DE_FLOAT16); | |||
| CASE(DE_FLOAT32); | |||
| CASE(DE_FLOAT64); | |||
| CASE(DE_STRING); | |||
| Status CacheRowRequest::Prepare() { | |||
| if (BitTest(rq_.flag(), kDataIsInSharedMemory)) { | |||
| // The first one is the cookie, followed by the address and then the size. | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() == 3, "Incomplete rpc data"); | |||
| } else { | |||
| // The first one is the cookie. The second one is the google flatbuffer header, followed by a number of buffers. | |||
| // But we are not going to decode them to verify. | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() >= 3, "Incomplete rpc data"); | |||
| } | |||
| #undef CASE | |||
| DataType type(dest); | |||
| std::shared_ptr<Tensor> ts; | |||
| RETURN_IF_NOT_OK( | |||
| Tensor::CreateFromMemory(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize(), &ts)); | |||
| // Next we restore the real data which can be embedded or stored separately. | |||
| if (ts->SizeInBytes() != data.GetSize()) { | |||
| MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n" | |||
| << "Dumping tensor\n" | |||
| << *ts << "\n"; | |||
| RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); | |||
| } | |||
| *out = std::move(ts); | |||
| return Status::OK(); | |||
| } | |||
| Status BatchFetchRequest::RestoreRows(TensorTable *out) { | |||
| BatchFetchRequest::BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, | |||
| bool local_bypass) | |||
| : BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(local_bypass), row_id_(row_id) { | |||
| rq_.set_connection_id(connection_id); | |||
| rq_.set_flag(support_local_bypass_ ? kLocalClientSupport : 0); | |||
| // Convert the row id into a flatbuffer | |||
| flatbuffers::FlatBufferBuilder fbb; | |||
| auto off_t = fbb.CreateVector(row_id); | |||
| TensorRowIdsBuilder bld(fbb); | |||
| bld.add_row_id(off_t); | |||
| auto off = bld.Finish(); | |||
| fbb.Finish(off); | |||
| rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize()); | |||
| } | |||
| Status BatchFetchRequest::RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| auto num_elements = row_id_.size(); | |||
| auto *offset_array = reinterpret_cast<const int64_t *>(mem_.GetPointer()); | |||
| const char *ptr = nullptr; | |||
| int64_t sz = 0; | |||
| // Tap into the reply flag to see where we can find the data. The server may decide the amount is | |||
| // so small that it doesn't use the shared memory method. | |||
| auto flag = reply_.flag(); | |||
| bool dataOnSharedMemory = support_local_bypass_ ? (BitTest(flag, kDataIsInSharedMemory)) : false; | |||
| if (dataOnSharedMemory) { | |||
| auto addr = strtoll(reply_.result().data(), nullptr, 10); | |||
| ptr = reinterpret_cast<const char *>(reinterpret_cast<int64_t>(baseAddr) + addr); | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| *out_addr = addr; | |||
| } else { | |||
| ptr = reply_.result().data(); | |||
| *out_addr = -1; | |||
| } | |||
| auto *offset_array = reinterpret_cast<const int64_t *>(ptr); | |||
| sz = offset_array[num_elements]; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(support_local_bypass_ || sz == reply_.result().length(), "Length mismatch"); | |||
| TensorTable tbl; | |||
| tbl.reserve(num_elements); | |||
| ReadableSlice all(mem_.GetPointer(), mem_.GetSizeInBytes()); | |||
| ReadableSlice all(ptr, sz); | |||
| for (auto i = 0; i < num_elements; ++i) { | |||
| auto len = offset_array[i + 1] - offset_array[i]; | |||
| TensorRow row; | |||
| @@ -178,10 +173,12 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out) { | |||
| auto col_ts = msg->column()->Get(k); | |||
| std::shared_ptr<Tensor> ts; | |||
| ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k)); | |||
| RETURN_IF_NOT_OK(RestoreOneTensor(col_ts, data, &ts)); | |||
| RETURN_IF_NOT_OK(mindspore::dataset::RestoreOneTensor(col_ts, data, &ts)); | |||
| row.push_back(ts); | |||
| ts_offset += data.GetSize(); | |||
| } | |||
| } else { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(len == 0, "Data corruption detected."); | |||
| } | |||
| tbl.push_back(std::move(row)); | |||
| } | |||
| @@ -189,36 +186,69 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out) { | |||
| return Status::OK(); | |||
| } | |||
| CreateCacheRequest::CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz, | |||
| CreateCacheRequest::CreateCacheFlag flag) | |||
| : BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag) { | |||
| // The type has already been set in the base constructor, so we only need to fill in the connection info. | |||
| // On successful return, we will get the connection id. | |||
| rq_.mutable_connection_info()->operator=(cinfo); | |||
| } | |||
| Status CreateCacheRequest::Prepare() { | |||
| try { | |||
| flatbuffers::FlatBufferBuilder fbb; | |||
| CreateCacheRequestMsgBuilder bld(fbb); | |||
| bld.add_cache_mem_sz(cache_mem_sz_); | |||
| bld.add_flag(static_cast<uint32_t>(flag_)); | |||
| auto off = bld.Finish(); | |||
| fbb.Finish(off); | |||
| rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize()); | |||
| return Status::OK(); | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||
| } | |||
| } | |||
| Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) { | |||
| try { | |||
| fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>(); | |||
| flatbuffers::FlatBufferBuilder fbb; | |||
| std::vector<flatbuffers::Offset<ColumnNameMsg>> v; | |||
| v.reserve(map.size()); | |||
| for (auto &column : map) { | |||
| auto c = CreateColumnNameMsg(*fbb_, fbb_->CreateString(column.first), column.second); | |||
| auto c = CreateColumnNameMsg(fbb, fbb.CreateString(column.first), column.second); | |||
| v.push_back(c); | |||
| } | |||
| auto v_off = fbb_->CreateVector(v); | |||
| auto final_off = CreateSchemaMsg(*fbb_, v_off); | |||
| fbb_->Finish(final_off); | |||
| buf_ = fbb_->GetBufferPointer(); | |||
| len_of_buf_ = fbb_->GetSize(); | |||
| auto v_off = fbb.CreateVector(v); | |||
| auto final_off = CreateSchemaMsg(fbb, v_off); | |||
| fbb.Finish(final_off); | |||
| rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize()); | |||
| return Status::OK(); | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||
| } | |||
| } | |||
| std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() { | |||
| if (column_name_id_map_.empty()) { | |||
| auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(mem_.GetPointer()); | |||
| auto v = map_msg->column(); | |||
| for (auto i = 0; i < v->size(); ++i) { | |||
| auto col = map_msg->column()->Get(i); | |||
| column_name_id_map_.emplace(col->name()->str(), col->id()); | |||
| } | |||
| Status FetchSchemaRequest::PostReply() { | |||
| auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(reply_.result().data()); | |||
| auto v = map_msg->column(); | |||
| for (auto i = 0; i < v->size(); ++i) { | |||
| auto col = map_msg->column()->Get(i); | |||
| column_name_id_map_.emplace(col->name()->str(), col->id()); | |||
| } | |||
| return column_name_id_map_; | |||
| return Status::OK(); | |||
| } | |||
| std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() { return column_name_id_map_; } | |||
| Status GetStatRequest::PostReply() { | |||
| auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(reply_.result().data()); | |||
| stat_.num_disk_cached = msg->num_disk_cached(); | |||
| stat_.num_mem_cached = msg->num_mem_cached(); | |||
| stat_.avg_cache_sz = msg->avg_cache_sz(); | |||
| stat_.max_row_id = msg->max_row_id(); | |||
| stat_.min_row_id = msg->min_row_id(); | |||
| stat_.cache_service_state = msg->state(); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -18,11 +18,16 @@ | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <iostream> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #ifdef ENABLE_CACHE | |||
| #include "proto/cache_grpc.grpc.pb.h" | |||
| #endif | |||
| #include "proto/cache_grpc.pb.h" | |||
| #include "minddata/dataset/core/tensor_row.h" | |||
| #include "minddata/dataset/engine/cache/de_tensor_generated.h" | |||
| #include "minddata/dataset/util/slice.h" | |||
| @@ -30,6 +35,17 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class CacheClient; | |||
| /// \brief Statistic structure for GetStat request | |||
| struct CacheServiceStat { | |||
| int64_t num_mem_cached; | |||
| int64_t num_disk_cached; | |||
| int64_t avg_cache_sz; | |||
| row_id_type min_row_id; | |||
| row_id_type max_row_id; | |||
| int8_t cache_service_state; | |||
| }; | |||
| /// \brief CacheClient communicates with CacheServer using Requests. | |||
| class BaseRequest { | |||
| public: | |||
| @@ -44,195 +60,301 @@ class BaseRequest { | |||
| kCacheSchema = 6, | |||
| kFetchSchema = 7, | |||
| kBuildPhaseDone = 8, | |||
| kDropSession = 9, | |||
| kGenerateSessionId = 10, | |||
| kAllocateSharedBlock = 11, | |||
| kFreeSharedBlock = 12, | |||
| kStopService = 13, | |||
| // Add new request types before this one. | |||
| kRequestUnknown = 32767 | |||
| }; | |||
| // For kCreateCache | |||
| enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L }; | |||
| friend class CacheServer; | |||
| friend class CacheServerRequest; | |||
| friend class CacheClientGreeter; | |||
| friend class CacheClientRequestTag; | |||
| /// \brief Base class of a cache server request | |||
| /// \param connection_id A combination of session id and crc that uniquely identifies a connection. | |||
| /// \param type Type of the request | |||
| explicit BaseRequest(connection_id_type connection_id, RequestType type) | |||
| : type_(type), connection_id_(connection_id) {} | |||
| explicit BaseRequest(RequestType type) : type_(type) { rq_.set_type(static_cast<google::int32>(type_)); } | |||
| virtual ~BaseRequest() = default; | |||
| /// \brief Wait for the completion of a request | |||
| /// \return Status returned from the cache server | |||
| Status Wait() { | |||
| RETURN_IF_NOT_OK(wp_.Wait()); | |||
| return rc_; | |||
| /// \brief A print method for debugging | |||
| /// \param out The output stream to write output to | |||
| virtual void Print(std::ostream &out) const { out << "Request type: " << static_cast<int16_t>(type_); } | |||
| /// \brief << Stream output operator overload | |||
| /// \param out reference to the output stream | |||
| /// \param rq reference to the BaseRequest | |||
| /// \return the output stream | |||
| friend std::ostream &operator<<(std::ostream &out, const BaseRequest &rq) { | |||
| rq.Print(out); | |||
| return out; | |||
| } | |||
| /// \brief Getter function of the current connection id | |||
| /// \return Connection id | |||
| connection_id_type GetServerConnectionId() const { return connection_id_; } | |||
| /// \brief Derived class can implement extra work to be done before the request is sent to the server | |||
| virtual Status Prepare() { return Status::OK(); } | |||
| /// \brief Derived class can implement extra work to be done after the reply comes back from the server | |||
| virtual Status PostReply() { return Status::OK(); } | |||
| /// \brief A method for the client to wait for the availability of the result back from the server. | |||
| /// \return Status object | |||
| Status Wait(); | |||
| protected: | |||
| CacheRequest rq_; // This is what we send to the server | |||
| CacheReply reply_; // This is what the server sends back | |||
| private: | |||
| RequestType type_; | |||
| connection_id_type connection_id_; | |||
| Status rc_; | |||
| WaitPost wp_; | |||
| WaitPost wp_; // A sync area used by the client side. | |||
| }; | |||
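| // Hedged sketch of the request lifecycle the hooks above imply (the actual driver lives in the | |||
| // grpc client code, which is not part of this hunk; names below come from this header only): | |||
| // | |||
| //   PurgeCacheRequest rq(connection_id); | |||
| //   RETURN_IF_NOT_OK(rq.Prepare());    // derived-class work before rq_ goes on the wire | |||
| //   /* ... transport sends rq_ and eventually fills reply_ ... */ | |||
| //   RETURN_IF_NOT_OK(rq.Wait());       // blocks on wp_ until completion is posted | |||
| //   RETURN_IF_NOT_OK(rq.PostReply());  // derived-class parsing of reply_, e.g. flatbuffers | |||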
| class FreeSharedBlockRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| explicit FreeSharedBlockRequest(connection_id_type connection_id, int64_t addr) | |||
| : BaseRequest(RequestType::kFreeSharedBlock) { | |||
| rq_.set_connection_id(connection_id); | |||
| rq_.add_buf_data(std::to_string(addr)); | |||
| } | |||
| ~FreeSharedBlockRequest() = default; | |||
| }; | |||
| /// \brief Request to cache a single TensorRow | |||
| class CacheRowRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie) | |||
| : BaseRequest(connection_id, RequestType::kCacheRow), row_id_from_server_(-1), cookie_(cookie) {} | |||
| friend class CacheClient; | |||
| explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie, bool local_bypass) | |||
| : BaseRequest(RequestType::kCacheRow), | |||
| support_local_bypass_(local_bypass), | |||
| addr_(-1), | |||
| sz_(0), | |||
| row_id_from_server_(-1) { | |||
| rq_.set_connection_id(connection_id); | |||
| rq_.add_buf_data(cookie); | |||
| } | |||
| ~CacheRowRequest() = default; | |||
| /// \brief Serialize a TensorRow for streaming to the cache server | |||
| /// \param row TensorRow | |||
| /// \return Status object | |||
| Status SerializeCacheRowRequest(const TensorRow &row); | |||
| Status SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row); | |||
| /// \brief Sanity check before we send the row. | |||
| /// \return Status object | |||
| Status Prepare() override; | |||
| /// \brief Override the base function to get the row id returned from the server | |||
| /// \return Status object | |||
| Status PostReply() override; | |||
| /// \brief Return the row id assigned to this row for non-mappable dataset | |||
| /// \return row id of the cached row | |||
| row_id_type GetRowIdAfterCache() { return row_id_from_server_; } | |||
| /// \brief If we are doing local bypass, fill in extra request information about where the data is located. | |||
| void AddDataLocation() { | |||
| if (support_local_bypass_) { | |||
| rq_.add_buf_data(std::to_string(addr_)); | |||
| rq_.add_buf_data(std::to_string(sz_)); | |||
| } | |||
| } | |||
| /// \brief If we fail to send the data to the server using the shared memory method, we should release | |||
| /// the shared memory by sending another request. The following function will generate a suitable | |||
| /// request for the CacheClient to send. | |||
| std::shared_ptr<FreeSharedBlockRequest> GenerateFreeBlockRequest() { | |||
| return std::make_shared<FreeSharedBlockRequest>(rq_.connection_id(), addr_); | |||
| } | |||
| private: | |||
| std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_; | |||
| bool support_local_bypass_; | |||
| int64_t addr_; | |||
| int64_t sz_; | |||
| row_id_type row_id_from_server_; | |||
| std::vector<const void *> buffers_; | |||
| std::string cookie_; | |||
| /// \brief Private function to serialize one TensorRow | |||
| /// \param row TensorRow | |||
| /// \return Status object | |||
| Status SerializeTensorRowHeader(const TensorRow &row); | |||
| /// \brief Private function to serialize one Tensor | |||
| /// \param ts_ptr Tensor | |||
| /// \return Status object | |||
| Status SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off); | |||
| }; | |||
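| // Hedged sketch of the local-bypass write path pieced together from the requests in this file | |||
| // (the real sequencing is done by CacheClient, which is not shown in this hunk): | |||
| //   1. AllocateSharedBlockRequest(conn_id, sz)  -> server replies with a relative addr | |||
| //   2. the client serializes the row at shared_base + addr and records addr_/sz_ | |||
| //   3. CacheRowRequest + AddDataLocation()      -> only {cookie, addr, sz} go on the wire | |||
| //   4. if the send fails, GenerateFreeBlockRequest() so the shared block is not leaked | |||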
| /// \brief Request to fetch rows in batch | |||
| class BatchFetchRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| friend class CacheService; | |||
| BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id) | |||
| : BaseRequest(connection_id, RequestType::kBatchFetchRows), row_id_(row_id) {} | |||
| BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, bool local_bypass); | |||
| ~BatchFetchRequest() = default; | |||
| Status RestoreRows(TensorTable *out); | |||
| Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr); | |||
| private: | |||
| bool support_local_bypass_; | |||
| std::vector<row_id_type> row_id_; | |||
| MemGuard<uint8_t> mem_; | |||
| Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out); | |||
| }; | |||
| /// \brief Request to create a cache for the current connection | |||
| class CreationCacheRequest : public BaseRequest { | |||
| class CreateCacheRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L }; | |||
| /// \brief Constructor | |||
| /// \param connection_id | |||
| /// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited | |||
| /// \param flag Attributes of the cache. | |||
| explicit CreationCacheRequest(connection_id_type connection_id, uint64_t cache_mem_sz, | |||
| CreateCacheFlag flag = CreateCacheFlag::kNone) | |||
| : BaseRequest(connection_id, RequestType::kCreateCache), cache_mem_sz(cache_mem_sz), flag_(flag) {} | |||
| ~CreationCacheRequest() = default; | |||
| explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz, | |||
| CreateCacheFlag flag = CreateCacheFlag::kNone); | |||
| ~CreateCacheRequest() = default; | |||
| void ParseResult(connection_id_type *id, std::string *out) { | |||
| auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data()); | |||
| *id = p->connection_id(); | |||
| *out = p->cookie()->str(); | |||
| } | |||
| std::string cookie() const { return cookie_; } | |||
| /// Overload the base class Prepare | |||
| Status Prepare() override; | |||
| private: | |||
| uint64_t cache_mem_sz; | |||
| uint64_t cache_mem_sz_; | |||
| CreateCacheFlag flag_; | |||
| std::string cookie_; | |||
| }; | |||
| /// \brief Request to purge a cache. | |||
| class PurgeCacheRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kPurgeCache) {} | |||
| explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kPurgeCache) { | |||
| rq_.set_connection_id(connection_id); | |||
| } | |||
| ~PurgeCacheRequest() = default; | |||
| }; | |||
| /// \brief Request to destroy a cache | |||
| class DestroyCacheRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| explicit DestroyCacheRequest(connection_id_type connection_id) | |||
| : BaseRequest(connection_id, RequestType::kDestroyCache) {} | |||
| /// \brief Destructor | |||
| explicit DestroyCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kDestroyCache) { | |||
| rq_.set_connection_id(connection_id); | |||
| } | |||
| ~DestroyCacheRequest() = default; | |||
| }; | |||
| /// \brief Obtain the statistics of the current connection | |||
| class GetStatRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| friend class CacheService; | |||
| explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kGetStat) {} | |||
| explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(RequestType::kGetStat) { | |||
| rq_.set_connection_id(connection_id); | |||
| } | |||
| ~GetStatRequest() = default; | |||
| row_id_type GetMinRowId() const { | |||
| auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer()); | |||
| return msg->min_row_id(); | |||
| } | |||
| row_id_type GetMaxRowId() const { | |||
| auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer()); | |||
| return msg->max_row_id(); | |||
| } | |||
| int64_t GetNumMemCached() const { | |||
| auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer()); | |||
| return msg->num_mem_cached(); | |||
| } | |||
| int64_t GetNumDiskCached() const { | |||
| auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer()); | |||
| return msg->num_disk_cached(); | |||
| } | |||
| uint8_t GetState() const { | |||
| auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer()); | |||
| return msg->state(); | |||
| /// \brief Override base function to process the result. | |||
| Status PostReply() override; | |||
| void GetStat(CacheServiceStat *stat) { | |||
| if (stat != nullptr) { | |||
| (*stat) = stat_; | |||
| } | |||
| } | |||
| private: | |||
| MemGuard<uint8_t> mem_; | |||
| CacheServiceStat stat_{}; | |||
| }; | |||
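| // Hedged usage sketch for the PostReply()-based stat flow above (assumes the transport has | |||
| // already completed the request and invoked PostReply()): | |||
| // | |||
| //   GetStatRequest rq(connection_id); | |||
| //   /* ... send rq and wait for completion ... */ | |||
| //   CacheServiceStat stat; | |||
| //   rq.GetStat(&stat);  // plain copy-out; callers never touch the flatbuffer directly | |||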
| /// \brief Request to cache a schema | |||
| class CacheSchemaRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| explicit CacheSchemaRequest(connection_id_type connection_id) | |||
| : BaseRequest(connection_id, RequestType::kCacheSchema), buf_(nullptr), len_of_buf_(0) {} | |||
| explicit CacheSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kCacheSchema) { | |||
| rq_.set_connection_id(connection_id); | |||
| } | |||
| ~CacheSchemaRequest() = default; | |||
| Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map); | |||
| const void *GetBuffer() const { return buf_; } | |||
| private: | |||
| std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_; | |||
| const void *buf_; | |||
| int64_t len_of_buf_; | |||
| }; | |||
| /// \brief Request to fetch a schema | |||
| class FetchSchemaRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| explicit FetchSchemaRequest(connection_id_type connection_id) | |||
| : BaseRequest(connection_id, RequestType::kFetchSchema) {} | |||
| explicit FetchSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kFetchSchema) { | |||
| rq_.set_connection_id(connection_id); | |||
| } | |||
| ~FetchSchemaRequest() = default; | |||
| Status PostReply() override; | |||
| std::unordered_map<std::string, int32_t> GetColumnMap(); | |||
| private: | |||
| MemGuard<uint8_t> mem_; | |||
| std::unordered_map<std::string, int32_t> column_name_id_map_; | |||
| }; | |||
| /// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only. | |||
| class BuildPhaseDoneRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie) | |||
| : BaseRequest(connection_id, RequestType::kBuildPhaseDone), cookie_(cookie) {} | |||
| : BaseRequest(RequestType::kBuildPhaseDone), cookie_(cookie) { | |||
| rq_.set_connection_id(connection_id); | |||
| rq_.add_buf_data(cookie_); | |||
| } | |||
| ~BuildPhaseDoneRequest() = default; | |||
| private: | |||
| std::string cookie_; | |||
| }; | |||
| /// \brief Request to drop all the caches in the current session | |||
| class DropSessionRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| explicit DropSessionRequest(const CacheClientInfo &cinfo) : BaseRequest(RequestType::kDropSession) { | |||
| rq_.mutable_connection_info()->operator=(cinfo); | |||
| } | |||
| ~DropSessionRequest() = default; | |||
| }; | |||
| class GenerateSessionIdRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| GenerateSessionIdRequest() : BaseRequest(RequestType::kGenerateSessionId) { | |||
| // We don't have any client info or connection id to send. But we will manually | |||
| // set the connection id to 0. | |||
| rq_.set_connection_id(0); | |||
| } | |||
| ~GenerateSessionIdRequest() = default; | |||
| session_id_type GetSessionId() { return atoi(reply_.result().data()); } | |||
| }; | |||
| class AllocateSharedBlockRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| explicit AllocateSharedBlockRequest(connection_id_type connection_id, size_t requestedSz) | |||
| : BaseRequest(RequestType::kAllocateSharedBlock) { | |||
| rq_.set_connection_id(connection_id); | |||
| rq_.add_buf_data(std::to_string(requestedSz)); | |||
| } | |||
| ~AllocateSharedBlockRequest() = default; | |||
| /// \brief On return from the server, we get the (relative) address where | |||
| /// the free block is located. | |||
| /// \return The relative address of the allocated block | |||
| int64_t GetAddr() { | |||
| auto addr = strtoll(reply_.result().data(), nullptr, 10); | |||
| return addr; | |||
| } | |||
| }; | |||
| class ShutdownRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| ShutdownRequest() : BaseRequest(RequestType::kStopService) {} | |||
| ~ShutdownRequest() = default; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_ | |||
| @@ -14,25 +14,89 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/cache/cache_server.h" | |||
| #include <algorithm> | |||
| #include <functional> | |||
| #include <limits> | |||
| #include "minddata/dataset/core/constants.h" | |||
| #include "minddata/dataset/engine/cache/cache_service.h" | |||
| #include "minddata/dataset/engine/cache/cache_request.h" | |||
| #include "minddata/dataset/util/bit.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| #include "minddata/dataset/util/random.h" | |||
| #ifdef CACHE_LOCAL_CLIENT | |||
| #include "minddata/dataset/util/sig_handler.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| CacheServer *CacheServer::instance_ = nullptr; | |||
| std::once_flag CacheServer::init_instance_flag_; | |||
| Status CacheServer::DoServiceStart() { | |||
| #ifdef CACHE_LOCAL_CLIENT | |||
| // We need to destroy the shared memory if the user hits Control-C | |||
| RegisterHandlers(); | |||
| #endif | |||
| if (!top_.empty()) { | |||
| Path spill(top_); | |||
| RETURN_IF_NOT_OK(spill.CreateDirectories()); | |||
| MS_LOG(INFO) << "CacheServer will use disk folder: " << top_; | |||
| } | |||
| RETURN_IF_NOT_OK(vg_.ServiceStart()); | |||
| cache_q_ = std::make_shared<Queue<BaseRequest *>>(1024); | |||
| // There will be num_workers_ threads working on the grpc queue and | |||
| // the same number of threads working on the CacheServerRequest queue. | |||
| // Like a connector object, we will set up the same number of queues, but | |||
| // we do not need to preserve any order. We will set the capacity of | |||
| // each queue to 128 since we are just pushing memory pointers, which | |||
| // are only 8 bytes each. | |||
| const int32_t que_capacity = 128; | |||
| // This is the request queue from the client | |||
| cache_q_ = std::make_shared<QueueList<CacheServerRequest *>>(); | |||
| cache_q_->Init(num_workers_, que_capacity); | |||
| // For the grpc completion queue to work, we need to allocate some | |||
| // tags, which in our case are instances of CacheServerRequest. | |||
| // They get recycled: we allocate them in advance and push them | |||
| // into a free list. We need more (two or three times) than the size | |||
| // of the cache_q, because while each worker is working on a CacheServerRequest, | |||
| // we need some extra tags in flight in the grpc completion queue. | |||
| const int32_t multiplier = 3; | |||
| const int32_t free_list_capacity = multiplier * (que_capacity + 1); | |||
| free_list_ = std::make_shared<QueueList<CacheServerRequest *>>(); | |||
| free_list_->Init(num_workers_, free_list_capacity); | |||
| // We need to keep a reference to the Services memory pool in case | |||
| // the Services singleton goes out of scope earlier than we do. | |||
| mp_ = Services::GetInstance().GetServiceMemPool(); | |||
| Allocator<CacheServerRequest> alloc(mp_); | |||
| tag_.reserve(num_workers_); | |||
| // Now we populate all the free lists. | |||
| for (auto m = 0; m < num_workers_; ++m) { | |||
| // Ideally we would allocate the whole free list in one malloc. But it turns out that exceeds the | |||
| // Arena size. So we will allocate one segment at a time. | |||
| auto my_tag = std::make_unique<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>(alloc); | |||
| // Allocate the tags and assign them to the current queue | |||
| RETURN_IF_NOT_OK(my_tag->allocate(free_list_capacity, m)); | |||
| for (int i = 0; i < free_list_capacity; ++i) { | |||
| RETURN_IF_NOT_OK(free_list_->operator[](m)->Add((*my_tag)[i])); | |||
| } | |||
| tag_.push_back(std::move(my_tag)); | |||
| } | |||
| RETURN_IF_NOT_OK(cache_q_->Register(&vg_)); | |||
| auto f = std::bind(&CacheServer::ServerRequest, this); | |||
| // Spawn a few threads to serve the request. | |||
| RETURN_IF_NOT_OK(free_list_->Register(&vg_)); | |||
| // Spawn a few threads to serve the real request. | |||
| auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1); | |||
| for (auto i = 0; i < num_workers_; ++i) { | |||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache service worker", std::bind(f, i))); | |||
| } | |||
| // Start the comm layer | |||
| try { | |||
| comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_, shared_memory_sz_in_gb_); | |||
| RETURN_IF_NOT_OK(comm_layer_->Run()); | |||
| } catch (const std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| } | |||
| // Finally loop forever to handle the request. | |||
| auto r = std::bind(&CacheServer::RpcRequest, this, std::placeholders::_1); | |||
| for (auto i = 0; i < num_workers_; ++i) { | |||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache server", f)); | |||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("rpc worker", std::bind(r, i))); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
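| // Illustrative summary (not in the original source) of how one CacheServerRequest tag cycles | |||
| // through the machinery set up above: | |||
| //   free_list_[m] --GetFreeRequestTag--> grpc completion queue (waiting for a client call) | |||
| //     --PushRequest--> cache_q_[m]        (a "Cache service worker" runs the request) | |||
| //     --responder_.Finish--> completion queue again (reply streamed back to the client) | |||
| //     --ReturnRequestTag--> destroy + placement-new, then back onto free_list_[m] | |||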
| @@ -65,188 +129,551 @@ CacheService *CacheServer::GetService(connection_id_type id) const { | |||
| return nullptr; | |||
| } | |||
| Status CacheServer::CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, | |||
| BaseRequest::CreateCacheFlag flag, std::string *out_cookie) { | |||
| Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing connection info"); | |||
| std::string cookie; | |||
| auto session_id = rq->connection_info().session_id(); | |||
| auto crc = rq->connection_info().crc(); | |||
| // We concatenate both numbers to form the internal connection id. | |||
| auto connection_id = GetConnectionID(session_id, crc); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing info to create cache"); | |||
| auto &create_cache_buf = rq->buf_data(0); | |||
| auto p = flatbuffers::GetRoot<CreateCacheRequestMsg>(create_cache_buf.data()); | |||
| auto flag = static_cast<CreateCacheRequest::CreateCacheFlag>(p->flag()); | |||
| auto cache_mem_sz = p->cache_mem_sz(); | |||
| // We can't do spilling unless this server is setup with a spill path in the first place | |||
| bool spill = (flag & BaseRequest::CreateCacheFlag::kSpillToDisk) == BaseRequest::CreateCacheFlag::kSpillToDisk; | |||
| bool spill = | |||
| (flag & CreateCacheRequest::CreateCacheFlag::kSpillToDisk) == CreateCacheRequest::CreateCacheFlag::kSpillToDisk; | |||
| bool generate_id = | |||
| (flag & BaseRequest::CreateCacheFlag::kGenerateRowId) == BaseRequest::CreateCacheFlag::kGenerateRowId; | |||
| (flag & CreateCacheRequest::CreateCacheFlag::kGenerateRowId) == CreateCacheRequest::CreateCacheFlag::kGenerateRowId; | |||
| if (spill && top_.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("Server is not set up with spill support."); | |||
| } | |||
| RETURN_UNEXPECTED_IF_NULL(out_cookie); | |||
| *out_cookie = ""; | |||
| flatbuffers::FlatBufferBuilder fbb; | |||
| flatbuffers::Offset<flatbuffers::String> off_cookie; | |||
| // Before creating the cache, first check if this is a request for a shared usage of an existing cache | |||
| // If two CreateService requests come in with identical connection_id, we need to serialize the creation. | |||
| // The first create will be successful and be given a special cookie. | |||
| UniqueLock lck(&rwLock_); | |||
| // Early exit if we are doing global shutdown | |||
| if (global_shutdown_) { | |||
| return Status::OK(); | |||
| } | |||
| auto end = all_caches_.end(); | |||
| auto it = all_caches_.find(connection_id); | |||
| bool duplicate = false; | |||
| if (it == end) { | |||
| std::unique_ptr<CacheService> cs; | |||
| try { | |||
| cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id); | |||
| RETURN_IF_NOT_OK(cs->ServiceStart()); | |||
| *out_cookie = cs->cookie(); | |||
| cookie = cs->cookie(); | |||
| all_caches_.emplace(connection_id, std::move(cs)); | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory); | |||
| } | |||
| } else { | |||
| duplicate = true; | |||
| MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service"; | |||
| // We can return OK, but we will return a duplicate key so the user can act accordingly: | |||
| // either ignore it or treat it as OK. | |||
| return Status(StatusCode::kDuplicateKey); | |||
| } | |||
| off_cookie = fbb.CreateString(cookie); | |||
| CreateCacheReplyMsgBuilder bld(fbb); | |||
| bld.add_connection_id(connection_id); | |||
| bld.add_cookie(off_cookie); | |||
| auto off = bld.Finish(); | |||
| fbb.Finish(off); | |||
| reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); | |||
| // Track the history of all the sessions that we have created so far. | |||
| history_sessions_.insert(session_id); | |||
| // We can return OK, but we will return a duplicate key so the user can act accordingly: | |||
| // either ignore it or treat it as OK. | |||
| return duplicate ? Status(StatusCode::kDuplicateKey) : Status::OK(); | |||
| } | |||
| Status CacheServer::DestroyCache(CacheService *cs, CacheRequest *rq) { | |||
| // We need a strong lock to protect the map. | |||
| UniqueLock lck(&rwLock_); | |||
| // it is already destroyed. Ignore it. | |||
| if (cs != nullptr) { | |||
| auto id = rq->connection_id(); | |||
| MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id); | |||
| // std::map will invoke the destructor of CacheService. So we don't need to do anything here. | |||
| auto n = all_caches_.erase(id); | |||
| if (n == 0) { | |||
| // It has been destroyed by another duplicate request. | |||
| MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service"; | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| inline Status CacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||
| auto connection_id = rq->connection_id(); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| auto sz = rq->buf_data_size(); | |||
| std::vector<const void *> buffers; | |||
| buffers.reserve(sz); | |||
| // First piece of data is the cookie and is required | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie"); | |||
| auto &cookie = rq->buf_data(0); | |||
| // Only if the cookie matches can we accept an insert into a cache that has a build phase | |||
| if (!cs->HasBuildPhase() || cookie == cs->cookie()) { | |||
| // Push the address of each buffer (in the form of std::string coming in from protobuf) into | |||
| // a vector of buffers | |||
| for (auto i = 1; i < sz; ++i) { | |||
| buffers.push_back(rq->buf_data(i).data()); | |||
| } | |||
| row_id_type id = -1; | |||
| RETURN_IF_NOT_OK(cs->CacheRow(buffers, &id)); | |||
| reply->set_result(std::to_string(id)); | |||
| } else { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| /// This is the main loop the cache server thread(s) are running. | |||
| /// Each thread will pop a request and save the result in the same request. | |||
| /// The sender will wait on the wait post in the request. Once the request | |||
| /// is fulfilled, the server thread will do a post signalling the request | |||
| /// is processed. | |||
| Status CacheServer::FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||
| auto connection_id = rq->connection_id(); | |||
| auto shared_pool = comm_layer_->GetSharedMemoryPool(); | |||
| auto *base = shared_pool->SharedMemoryBaseAddr(); | |||
| // Ensure we got 3 pieces of data coming in | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() == 3, "Incomplete data"); | |||
| // First piece of data is the cookie and is required | |||
| auto &cookie = rq->buf_data(0); | |||
| // Second piece of data is the address where we can find the serialized data | |||
| auto addr = strtoll(rq->buf_data(1).data(), nullptr, 10); | |||
| auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr); | |||
| // Third piece of data is the size of the serialized data that we need to transfer | |||
| auto sz = strtoll(rq->buf_data(2).data(), nullptr, 10); | |||
| // Successful or not, we need to free the memory on exit. | |||
| Status rc; | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found"; | |||
| rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| // Only if the cookie matches can we accept an insert into a cache that has a build phase | |||
| if (!cs->HasBuildPhase() || cookie == cs->cookie()) { | |||
| row_id_type id = -1; | |||
| ReadableSlice src(p, sz); | |||
| rc = cs->FastCacheRow(src, &id); | |||
| reply->set_result(std::to_string(id)); | |||
| } else { | |||
| rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); | |||
| } | |||
| } | |||
| // Return the block to the shared memory. | |||
| shared_pool->Deallocate(p); | |||
| return rc; | |||
| } | |||
| Status CacheServer::BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||
| auto connection_id = rq->connection_id(); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing row id"); | |||
| auto &row_id_buf = rq->buf_data(0); | |||
| auto p = flatbuffers::GetRoot<TensorRowIds>(row_id_buf.data()); | |||
| std::vector<row_id_type> row_id; | |||
| auto sz = p->row_id()->size(); | |||
| row_id.reserve(sz); | |||
| for (auto i = 0; i < sz; ++i) { | |||
| row_id.push_back(p->row_id()->Get(i)); | |||
| } | |||
| int64_t mem_sz = 0; | |||
| std::vector<key_size_pair> v; | |||
| RETURN_IF_NOT_OK(cs->PreBatchFetch(row_id, &v, &mem_sz)); | |||
| auto client_flag = rq->flag(); | |||
| bool local_client = BitTest(client_flag, kLocalClientSupport); | |||
| // For a large amount of data to be sent back, we will use shared memory, provided it is a local | |||
| // client that has local bypass support. | |||
| bool local_bypass = local_client ? (mem_sz >= kLocalByPassThreshold) : false; | |||
| reply->set_flag(local_bypass ? kDataIsInSharedMemory : 0); | |||
| if (local_bypass) { | |||
| // We will use shared memory | |||
| auto shared_pool = comm_layer_->GetSharedMemoryPool(); | |||
| auto *base = shared_pool->SharedMemoryBaseAddr(); | |||
| void *q = nullptr; | |||
| RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q)); | |||
| WritableSlice dest(q, mem_sz); | |||
| RETURN_IF_NOT_OK(cs->BatchFetch(row_id, v, &dest)); | |||
| // We can't return the absolute address, which would make no sense to the client. | |||
| // Instead we return the difference (offset from the shared memory base). | |||
| auto difference = reinterpret_cast<int64_t>(q) - reinterpret_cast<int64_t>(base); | |||
| reply->set_result(std::to_string(difference)); | |||
| } else { | |||
| // We are going to use std::string to allocate and hold the result, which will eventually be | |||
| // 'moved' to the protobuf message (which underneath is also a std::string), in order to | |||
| // minimize memory copies. | |||
| std::string mem; | |||
| try { | |||
| mem.resize(mem_sz); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= mem_sz, "Programming error"); | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory); | |||
| } | |||
| WritableSlice dest(mem.data(), mem_sz); | |||
| RETURN_IF_NOT_OK(cs->BatchFetch(row_id, v, &dest)); | |||
| reply->set_result(std::move(mem)); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||
| auto connection_id = rq->connection_id(); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| CacheService::ServiceStat svc_stat; | |||
| RETURN_IF_NOT_OK(cs->GetStat(&svc_stat)); | |||
| flatbuffers::FlatBufferBuilder fbb; | |||
| ServiceStatMsgBuilder bld(fbb); | |||
| bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached); | |||
| bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached); | |||
| bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz); | |||
| bld.add_max_row_id(svc_stat.max_); | |||
| bld.add_min_row_id(svc_stat.min_); | |||
| bld.add_state(svc_stat.state_); | |||
| auto offset = bld.Finish(); | |||
| fbb.Finish(offset); | |||
| reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| inline Status CacheSchema(CacheService *cs, CacheRequest *rq) { | |||
| auto connection_id = rq->connection_id(); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing schema information"); | |||
| auto &create_schema_buf = rq->buf_data(0); | |||
| RETURN_IF_NOT_OK(cs->CacheSchema(create_schema_buf.data(), create_schema_buf.size())); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| inline Status FetchSchema(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||
| auto connection_id = rq->connection_id(); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| // We are going to use std::string to allocate and hold the result, which will eventually be | |||
| // 'moved' to the protobuf message (which underneath is also a std::string), in order to | |||
| // minimize memory copies. | |||
| std::string mem; | |||
| RETURN_IF_NOT_OK(cs->FetchSchema(&mem)); | |||
| reply->set_result(std::move(mem)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| inline Status BuildPhaseDone(CacheService *cs, CacheRequest *rq) { | |||
| auto connection_id = rq->connection_id(); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| // First piece of data is the cookie | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie"); | |||
| auto &cookie = rq->buf_data(0); | |||
| // We can only allow the phase switch if the cookie matches. | |||
| if (cookie == cs->cookie()) { | |||
| RETURN_IF_NOT_OK(cs->BuildPhaseDone()); | |||
| } else { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::PurgeCache(CacheService *cs) { | |||
| SharedLock lck(&rwLock_); | |||
| // If shutdown in progress, ignore the command. | |||
| if (global_shutdown_) { | |||
| return Status::OK(); | |||
| } | |||
| // it is already purged. Ignore it. | |||
| if (cs != nullptr) { | |||
| RETURN_IF_NOT_OK(cs->Purge()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| inline Status GenerateClientSessionID(session_id_type session_id, CacheReply *reply) { | |||
| reply->set_result(std::to_string(session_id)); | |||
| return Status::OK(); | |||
| } | |||
| /// \brief This is the main loop the cache server thread(s) are running. | |||
| /// Each thread will pop a request and send the result back to the client using grpc | |||
| /// \return Status object | |||
| Status CacheServer::ServerRequest() { | |||
| Status CacheServer::ServerRequest(int32_t worker_id) { | |||
| TaskManager::FindMe()->Post(); | |||
| // Loop forever until we are interrupted. | |||
| while (true) { | |||
| BaseRequest *base_rq = nullptr; | |||
| RETURN_IF_NOT_OK(cache_q_->PopFront(&base_rq)); | |||
| auto cs = GetService(base_rq->connection_id_); | |||
| auto &my_que = cache_q_->operator[](worker_id); | |||
| // Loop forever until we are interrupted or shutdown. | |||
| while (!global_shutdown_) { | |||
| CacheServerRequest *cache_req = nullptr; | |||
| RETURN_IF_NOT_OK(my_que->PopFront(&cache_req)); | |||
| auto &rq = cache_req->rq_; | |||
| auto &reply = cache_req->reply_; | |||
| CacheService *cs = nullptr; | |||
| // Requests come in roughly two sets. One set works at the cache level with a connection id. | |||
| // The other set works at the server level, without a connection id. | |||
| if (!rq.has_connection_info()) { | |||
| cs = GetService(rq.connection_id()); | |||
| } | |||
| // Except for creating a new session, we expect cs to be non-null. | |||
| switch (base_rq->type_) { | |||
| switch (cache_req->type_) { | |||
| case BaseRequest::RequestType::kCacheRow: { | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found"; | |||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| // Look into the flag to see where we can find the data and | |||
| // call the appropriate method. | |||
| auto flag = rq.flag(); | |||
| if (BitTest(flag, kDataIsInSharedMemory)) { | |||
| cache_req->rc_ = FastCacheRow(cs, &rq, &reply); | |||
| } else { | |||
| auto *rq = reinterpret_cast<CacheRowRequest *>(base_rq); | |||
| // Only if the cookie matches can we accept an insert into a cache that has a build phase | |||
| if (!cs->HasBuildPhase() || rq->cookie_ == cs->cookie()) { | |||
| rq->rc_ = cs->CacheRow(rq->buffers_, &rq->row_id_from_server_); | |||
| } else { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); | |||
| } | |||
| cache_req->rc_ = CacheRow(cs, &rq, &reply); | |||
| } | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kBatchFetchRows: { | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found"; | |||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| auto *rq = reinterpret_cast<BatchFetchRequest *>(base_rq); | |||
| rq->rc_ = cs->BatchFetch(rq->row_id_, &rq->mem_); | |||
| } | |||
| cache_req->rc_ = BatchFetchRows(cs, &rq, &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kCreateCache: { | |||
| // If the cache has already been created, we still need to run the creation so that we can do sanity | |||
| // checks on the client id and return the cache id back to the user. | |||
| auto *rq = reinterpret_cast<CreationCacheRequest *>(base_rq); | |||
| rq->rc_ = CreateService(rq->connection_id_, rq->cache_mem_sz, rq->flag_, &rq->cookie_); | |||
| cache_req->rc_ = CreateService(&rq, &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kPurgeCache: { | |||
| if (cs != nullptr) { | |||
| base_rq->rc_ = cs->Purge(); | |||
| } else { | |||
| // it is already purged. Ignore it. | |||
| base_rq->rc_ = Status::OK(); | |||
| } | |||
| cache_req->rc_ = PurgeCache(cs); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kDestroyCache: { | |||
| if (cs != nullptr) { | |||
| // We need a strong lock to protect the map. | |||
| connection_id_type id = base_rq->connection_id_; | |||
| UniqueLock lck(&rwLock_); | |||
| // std::map will invoke the destructor of CacheService. So we don't need to do anything here. | |||
| auto n = all_caches_.erase(id); | |||
| if (n == 0) { | |||
| // It has been destroyed by another duplicate request. | |||
| MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service"; | |||
| } | |||
| base_rq->rc_ = Status::OK(); | |||
| } else { | |||
| // it is already destroyed. Ignore it. | |||
| base_rq->rc_ = Status::OK(); | |||
| } | |||
| cache_req->rc_ = DestroyCache(cs, &rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kGetStat: { | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; | |||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| auto *rq = reinterpret_cast<GetStatRequest *>(base_rq); | |||
| CacheService::ServiceStat svc_stat; | |||
| rq->rc_ = cs->GetStat(&svc_stat); | |||
| if (rq->rc_.IsOk()) { | |||
| flatbuffers::FlatBufferBuilder fbb; | |||
| ServiceStatMsgBuilder bld(fbb); | |||
| bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached); | |||
| bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached); | |||
| bld.add_max_row_id(svc_stat.max_); | |||
| bld.add_min_row_id(svc_stat.min_); | |||
| bld.add_state(svc_stat.state_); | |||
| auto offset = bld.Finish(); | |||
| fbb.Finish(offset); | |||
| rq->rc_ = rq->mem_.allocate(fbb.GetSize()); | |||
| if (rq->rc_.IsOk()) { | |||
| WritableSlice dest(rq->mem_.GetMutablePointer(), fbb.GetSize()); | |||
| ReadableSlice src(fbb.GetBufferPointer(), fbb.GetSize()); | |||
| RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src)); | |||
| } | |||
| } | |||
| } | |||
| cache_req->rc_ = GetStat(cs, &rq, &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kCacheSchema: { | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; | |||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| auto *rq = reinterpret_cast<CacheSchemaRequest *>(base_rq); | |||
| rq->rc_ = cs->CacheSchema(rq->buf_, rq->len_of_buf_); | |||
| } | |||
| cache_req->rc_ = CacheSchema(cs, &rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kFetchSchema: { | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; | |||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| auto *rq = reinterpret_cast<FetchSchemaRequest *>(base_rq); | |||
| rq->rc_ = cs->FetchSchema(&rq->mem_); | |||
| } | |||
| cache_req->rc_ = FetchSchema(cs, &rq, &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kBuildPhaseDone: { | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; | |||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| auto *rq = reinterpret_cast<BuildPhaseDoneRequest *>(base_rq); | |||
| // We can only allow the phase switch if the cookie matches. | |||
| if (rq->cookie_ == cs->cookie()) { | |||
| rq->rc_ = cs->BuildPhaseDone(); | |||
| } else { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); | |||
| } | |||
| } | |||
| cache_req->rc_ = BuildPhaseDone(cs, &rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kDropSession: { | |||
| cache_req->rc_ = DestroySession(&rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kGenerateSessionId: { | |||
| cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kAllocateSharedBlock: { | |||
| cache_req->rc_ = AllocateSharedMemory(&rq, &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kFreeSharedBlock: { | |||
| cache_req->rc_ = FreeSharedMemory(&rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kStopService: { | |||
| // This command shuts down everything. | |||
| cache_req->rc_ = GlobalShutdown(); | |||
| break; | |||
| } | |||
| default: | |||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Unknown request type"); | |||
| std::string errMsg("Unknown request type : "); | |||
| errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_)); | |||
| cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } | |||
| // Notify it is done, and move on to the next request. | |||
| base_rq->wp_.Set(); | |||
| Status2CacheReply(cache_req->rc_, &reply); | |||
| cache_req->st_ = CacheServerRequest::STATE::FINISH; | |||
| // We will re-tag the request back to the grpc queue. Once it comes back from the client, | |||
| // the CacheServerRequest, i.e. the pointer cache_req, will be freed. | |||
| if (!global_shutdown_) { | |||
| cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| connection_id_type CacheServer::GetConnectionID(session_id_type session_id, uint32_t crc) const { | |||
| connection_id_type connection_id = | |||
| (static_cast<connection_id_type>(session_id) << 32u) | static_cast<connection_id_type>(crc); | |||
| return connection_id; | |||
| } | |||
| session_id_type CacheServer::GetSessionID(connection_id_type connection_id) const { | |||
| return static_cast<session_id_type>(connection_id >> 32u); | |||
| } | |||
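| // Worked example of the packing above (illustrative values): with session_id = 0x00001234 and | |||
| // crc = 0xDEADBEEF, GetConnectionID() yields 0x00001234DEADBEEF, and GetSessionID() shifts the | |||
| // high 32 bits back out to recover 0x00001234. | |||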
| CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, | |||
| int32_t shared_memory_sz_in_gb) | |||
| : top_(spill_path), | |||
| num_workers_(num_workers), | |||
| port_(port), | |||
| shared_memory_sz_in_gb_(shared_memory_sz_in_gb), | |||
| global_shutdown_(false) {} | |||
| Status CacheServer::Run() { | |||
| RETURN_IF_NOT_OK(ServiceStart()); | |||
| // This is called by the main function and we shouldn't exit. Otherwise the main thread | |||
| // will just shut down. So we will call some function that never returns unless there is an error. | |||
| // One good choice is simply to wait for all threads to return. | |||
| RETURN_IF_NOT_OK(vg_.join_all(Task::WaitFlag::kBlocking)); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::GetFreeRequestTag(int32_t queue_id, CacheServerRequest **q) { | |||
| RETURN_UNEXPECTED_IF_NULL(q); | |||
| CacheServer &cs = CacheServer::GetInstance(); | |||
| CacheServerRequest *p; | |||
| RETURN_IF_NOT_OK(cs.free_list_->operator[](queue_id)->PopFront(&p)); | |||
| *q = p; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::ReturnRequestTag(CacheServerRequest *p) { | |||
| RETURN_UNEXPECTED_IF_NULL(p); | |||
| int32_t myQID = p->getQid(); | |||
| // Free any memory from the protobufs | |||
| p->~CacheServerRequest(); | |||
| // Re-initialize the memory | |||
| new (p) CacheServerRequest(myQID); | |||
| // Now we return it back to free list. | |||
| CacheServer &cs = CacheServer::GetInstance(); | |||
| RETURN_IF_NOT_OK(cs.free_list_->operator[](myQID)->Add(p)); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::DestroySession(CacheRequest *rq) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id"); | |||
| auto drop_session_id = rq->connection_info().session_id(); | |||
| UniqueLock lck(&rwLock_); | |||
| for (auto &cs : all_caches_) { | |||
| auto connection_id = cs.first; | |||
| auto session_id = GetSessionID(connection_id); | |||
| // We could just call DestroyCache(), but we are already holding a lock, and doing so would | |||
| // cause a deadlock. So we will just do it manually. | |||
| if (session_id == drop_session_id) { | |||
| // std::map will invoke the destructor of CacheService. So we don't need to do anything here. | |||
| auto n = all_caches_.erase(connection_id); | |||
| MS_LOG(INFO) << "Destroy " << n << " copies of cache with id " << connection_id; | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| session_id_type CacheServer::GenerateSessionID() const { | |||
| SharedLock lock(&rwLock_); | |||
| auto mt = GetRandomDevice(); | |||
| std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<session_id_type>::max()); | |||
| session_id_type session_id; | |||
| bool duplicate = false; | |||
| do { | |||
| session_id = distribution(mt); | |||
| auto it = history_sessions_.find(session_id); | |||
| duplicate = (it != history_sessions_.end()); | |||
| } while (duplicate); | |||
| return session_id; | |||
| } | |||
| Status CacheServer::AllocateSharedMemory(CacheRequest *rq, CacheReply *reply) { | |||
| auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, 10); | |||
| auto shared_pool = comm_layer_->GetSharedMemoryPool(); | |||
| auto *base = shared_pool->SharedMemoryBaseAddr(); | |||
| void *p = nullptr; | |||
| RETURN_IF_NOT_OK(shared_pool->Allocate(requestedSz, &p)); | |||
| // We can't return the absolute address, which would make no sense to the client. | |||
| // Instead we return the difference (offset from the shared memory base). | |||
| auto difference = reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base); | |||
| reply->set_result(std::to_string(difference)); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::FreeSharedMemory(CacheRequest *rq) { | |||
| auto shared_pool = comm_layer_->GetSharedMemoryPool(); | |||
| auto *base = shared_pool->SharedMemoryBaseAddr(); | |||
| auto addr = strtoll(rq->buf_data(0).data(), nullptr, 10); | |||
| auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr); | |||
| shared_pool->Deallocate(p); | |||
| return Status::OK(); | |||
| } | |||
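| // Note (illustrative, not original source): both AllocateSharedMemory() and FreeSharedMemory() | |||
| // speak in offsets relative to the shared memory base. The server hands out p - base, and each | |||
| // side recovers a usable pointer with its own mapping via base + offset, since the absolute | |||
| // address of the mapping differs between the server and a local client process. | |||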
| Status CacheServer::RpcRequest(int32_t worker_id) { | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(comm_layer_->HandleRequest(worker_id)); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::GlobalShutdown() { | |||
| // Let's shutdown in proper order. | |||
| bool expected = false; | |||
| if (global_shutdown_.compare_exchange_strong(expected, true)) { | |||
| MS_LOG(WARNING) << "Shutting down server."; | |||
| // Shutdown the grpc queue. No longer accept any newcomers. | |||
| // The threads we spawned to work on the grpc queue will exit themselves once | |||
| // they notice the queue has been shut down. | |||
| comm_layer_->Shutdown(); | |||
| // Now we interrupt any threads that are waiting on cache_q_ | |||
| vg_.interrupt_all(); | |||
| // The next thing to do is drop all the caches. | |||
| UniqueLock lck(&rwLock_); | |||
| for (auto &it : all_caches_) { | |||
| auto id = it.first; | |||
| MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id); | |||
| // Wait for all outstanding work to be finished. | |||
| auto &cs = it.second; | |||
| UniqueLock cs_lock(&cs->rw_lock_); | |||
| // std::map will invoke the destructor of CacheService. So we don't need to do anything here. | |||
| (void)all_caches_.erase(id); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::Builder::SanityCheck() { | |||
| if (shared_memory_sz_in_gb_ <= 0) { | |||
| RETURN_STATUS_UNEXPECTED("Shared memory size (in GB unit) must be positive"); | |||
| } | |||
| if (num_workers_ <= 0) { | |||
| RETURN_STATUS_UNEXPECTED("Number of parallel workers must be positive"); | |||
| } | |||
| if (!top_.empty()) { | |||
| auto p = top_.data(); | |||
| if (p[0] != '/') { | |||
| RETURN_STATUS_UNEXPECTED("Spilling directory must be an absolute path"); | |||
| } | |||
| // Check if the spill directory is writable | |||
| Path spill(top_); | |||
| auto t = spill / Services::GetUniqueID(); | |||
| Status rc = t.CreateDirectory(); | |||
| if (rc.IsOk()) { | |||
| rc = t.Remove(); | |||
| } | |||
| if (rc.IsError()) { | |||
| RETURN_STATUS_UNEXPECTED("Spilling directory is not writable\n" + rc.ToString()); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers) | |||
| : top_(spill_path), num_workers_(num_workers) {} | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -24,8 +24,11 @@ | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <map> | |||
| #include <set> | |||
| #include "minddata/dataset/engine/cache/cache_service.h" | |||
| #include "minddata/dataset/engine/cache/cache_grpc_server.h" | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/util/allocator.h" | |||
| #include "minddata/dataset/util/arena.h" | |||
| #include "minddata/dataset/util/cache_pool.h" | |||
| #include "minddata/dataset/util/lock.h" | |||
| @@ -37,43 +40,131 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class BaseRequest; | |||
| /// \brief A server which provides CacheService services. | |||
| class CacheServer : public Service { | |||
| public: | |||
| friend class Services; | |||
| using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>; | |||
| class Builder { | |||
| public: | |||
| Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4) {} | |||
| /// \brief Getter functions | |||
| const std::string &getTop() const { return top_; } | |||
| int32_t getNumWorkers() const { return num_workers_; } | |||
| int32_t getPort() const { return port_; } | |||
| int32_t getSharedMemorySzInGb() const { return shared_memory_sz_in_gb_; } | |||
| Builder &SetRootDirectory(std::string root) { | |||
| top_ = std::move(root); | |||
| return *this; | |||
| } | |||
| Builder &SetNumWorkers(int32_t n) { | |||
| num_workers_ = n; | |||
| return *this; | |||
| } | |||
| Builder &SetPort(int32_t p) { | |||
| port_ = p; | |||
| return *this; | |||
| } | |||
| Builder &SetSharedMemorySizeInGB(int32_t sz) { | |||
| shared_memory_sz_in_gb_ = sz; | |||
| return *this; | |||
| } | |||
| Status SanityCheck(); | |||
| void Print(std::ostream &out) const { | |||
| out << "Summary of the cache server configuration\n" | |||
| << "Spill directory: " << getTop() << "\n" | |||
| << "Number of parallel workers: " << getNumWorkers() << "\n" | |||
| << "Tcp/ip port: " << getPort() << "\n" | |||
| << "Shared memory size (in GB): " << getSharedMemorySzInGb(); | |||
| } | |||
| friend std::ostream &operator<<(std::ostream &out, const Builder &bld) { | |||
| bld.Print(out); | |||
| return out; | |||
| } | |||
| Status Build() { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| // We need to bring up the Task Manager by bringing up the Services singleton. | |||
| RETURN_IF_NOT_OK(Services::CreateInstance()); | |||
| RETURN_IF_NOT_OK(CacheServer::CreateInstance(top_, num_workers_, port_, shared_memory_sz_in_gb_)); | |||
| return Status::OK(); | |||
| } | |||
| private: | |||
| std::string top_; | |||
| int32_t num_workers_; | |||
| int32_t port_; | |||
| int32_t shared_memory_sz_in_gb_; | |||
| }; | |||
| CacheServer(const CacheServer &) = delete; | |||
| CacheServer &operator=(const CacheServer &) = delete; | |||
| CacheServer(CacheServer &&) = delete; | |||
| CacheServer &operator=(CacheServer &) = delete; | |||
| static CacheServer &GetInstance() noexcept { return Services::getCacheServer(); } | |||
| Status DoServiceStart() override; | |||
| Status DoServiceStop() override; | |||
| ~CacheServer() { (void)ServiceStop(); } | |||
| static Status CreateInstance(const std::string &spill_path, int32_t num_workers, int32_t port, | |||
| int32_t shared_memory_sz) { | |||
| std::call_once(init_instance_flag_, [&]() -> Status { | |||
| auto &svcManager = Services::GetInstance(); | |||
| RETURN_IF_NOT_OK(svcManager.AddHook(&instance_, spill_path, num_workers, port, shared_memory_sz)); | |||
| return Status::OK(); | |||
| }); | |||
| return Status::OK(); | |||
| } | |||
| static CacheServer &GetInstance() { return *instance_; } | |||
| /// \brief For the current demonstration, a cache client contacts the cache server using a Queue. | |||
| /// \param rq | |||
| /// \return Status object | |||
| Status PushRequest(BaseRequest *rq) { | |||
| Status PushRequest(int32_t queue_id, CacheServerRequest *rq) { | |||
| RETURN_UNEXPECTED_IF_NULL(rq); | |||
| RETURN_IF_NOT_OK(cache_q_->Add(rq)); | |||
| RETURN_IF_NOT_OK(cache_q_->operator[](queue_id)->Add(rq)); | |||
| return Status::OK(); | |||
| } | |||
| /// \brief Kick off server threads. Never returns unless an error occurs. | |||
| Status Run(); | |||
| /// \brief Get a free tag | |||
| /// \param[in] q Pointer to a pointer to a CacheServerRequest | |||
| /// \return Status object | |||
| static Status GetFreeRequestTag(int32_t queue_id, CacheServerRequest **q); | |||
| /// \brief Return a tag to the free list | |||
| /// \param[in] p Pointer to an already finished CacheServerRequest tag | |||
| /// \return Status object | |||
| static Status ReturnRequestTag(CacheServerRequest *p); | |||
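| // Sketch of the intended recycling pattern (inferred from the declarations | |||
| // above; the actual call sites live in the grpc .cc files, not in this hunk): | |||
| // a worker pulls a tag from the per-queue free list instead of heap-allocating | |||
| // one request object per rpc, then returns it when the rpc completes. | |||
| //   CacheServerRequest *tag = nullptr; | |||
| //   RETURN_IF_NOT_OK(GetFreeRequestTag(queue_id, &tag)); | |||
| //   ... use `tag` as the grpc completion-queue tag ... | |||
| //   RETURN_IF_NOT_OK(ReturnRequestTag(tag)); | |||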
| private: | |||
| static std::once_flag init_instance_flag_; | |||
| static CacheServer *instance_; | |||
| mutable RWLock rwLock_; | |||
| std::string top_; | |||
| cache_index all_caches_; | |||
| std::shared_ptr<Queue<BaseRequest *>> cache_q_; | |||
| std::set<session_id_type> history_sessions_; | |||
| std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_; | |||
| std::shared_ptr<QueueList<CacheServerRequest *>> free_list_; | |||
| std::vector<std::unique_ptr<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>> tag_; | |||
| std::shared_ptr<CacheServerGreeterImpl> comm_layer_; | |||
| std::shared_ptr<MemoryPool> mp_; | |||
| TaskGroup vg_; | |||
| int32_t num_workers_; | |||
| int32_t port_; | |||
| int32_t shared_memory_sz_in_gb_; | |||
| std::atomic<bool> global_shutdown_; | |||
| /// \brief Constructor | |||
| /// \param spill_path Top directory for spilling buffers to. | |||
| /// \param num_workers Number of threads for handling requests. | |||
| explicit CacheServer(const std::string &spill_path, int32_t num_workers = 3); | |||
| explicit CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, int32_t share_memory_sz_in_gb); | |||
| /// \brief Locate a cache service from connection id. | |||
| /// \return Pointer to cache service. Null if not found | |||
| @@ -82,16 +173,65 @@ class CacheServer : public Service { | |||
| /// \brief Create a cache service. We allow multiple clients to create the same cache service. | |||
| /// Subsequent duplicate requests are ignored. The first cache client to create the service will be given | |||
| /// a special unique cookie. | |||
| /// \param[in] connection_id This is from a Cache client. | |||
| /// \param[in] cache_mem_sz | |||
| /// \param[in] flag | |||
| /// \param[out] out_cookie Only the first cache client will be given a special cookie to identify the creator | |||
| /// \return Status object | |||
| Status CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, BaseRequest::CreateCacheFlag flag, | |||
| std::string *out_cookie); | |||
| Status CreateService(CacheRequest *rq, CacheReply *reply); | |||
| /// \brief Destroy a cache service | |||
| /// \param cs | |||
| /// \param rq | |||
| /// \return | |||
| Status DestroyCache(CacheService *cs, CacheRequest *rq); | |||
| Status PurgeCache(CacheService *cs); | |||
| /// \brief Entry point for all internal server threads. | |||
| Status ServerRequest(int32_t worker_id); | |||
| /// \brief Entry point for all grpc threads. | |||
| /// \return | |||
| Status RpcRequest(int32_t worker_id); | |||
| Status DestroySession(CacheRequest *rq); | |||
| /// \brief Create a connection id from a session id and a crc | |||
| /// \param session_id | |||
| /// \param crc | |||
| /// \return connection id | |||
| connection_id_type GetConnectionID(session_id_type session_id, uint32_t crc) const; | |||
| /// \brief Extract the session id from a connection id | |||
| /// \param connection_id | |||
| /// \return session id | |||
| session_id_type GetSessionID(connection_id_type connection_id) const; | |||
| /// \brief Generate a session ID for the client | |||
| /// \return Session ID | |||
| session_id_type GenerateSessionID() const; | |||
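| // One plausible encoding for the id pair above (an assumption for illustration; | |||
| // the actual implementation is in the .cc file and is not part of this hunk): | |||
| // pack the session id into the high 32 bits and the crc into the low 32 bits, | |||
| //   connection_id = (static_cast<uint64_t>(session_id) << 32) | crc; | |||
| //   session_id = static_cast<uint32_t>(connection_id >> 32); | |||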
| /// \brief Handle kAllocateSharedBlock request | |||
| /// \param rq CacheRequest | |||
| /// \param reply CacheReply | |||
| /// \return Status object | |||
| Status AllocateSharedMemory(CacheRequest *rq, CacheReply *reply); | |||
| /// \brief Handle kFreeSharedBlock request | |||
| /// \param rq | |||
| /// \return Status object | |||
| Status FreeSharedMemory(CacheRequest *rq); | |||
| /// \brief Entry point for all server threads. | |||
| Status ServerRequest(); | |||
| /// \brief Handle kFastCacheRow request | |||
| /// \return Status object | |||
| Status FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply); | |||
| /// \brief Internal function to do row batch fetch | |||
| /// \param cs CacheService | |||
| /// \param rq Request | |||
| /// \param reply Reply | |||
| /// \return | |||
| Status BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply); | |||
| /// \brief A proper shutdown of the server | |||
| /// \return Status object | |||
| Status GlobalShutdown(); | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -76,7 +76,7 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type | |||
| *row_id_generated = GetNextRowId(); | |||
| // Some debug information on how many rows we have generated so far. | |||
| if ((*row_id_generated) % 1000 == 0) { | |||
| MS_LOG(DEBUG) << "Number of rows cached: " << *row_id_generated; | |||
| MS_LOG(DEBUG) << "Number of rows cached: " << (*row_id_generated) + 1; | |||
| } | |||
| } else { | |||
| if (msg->row_id() < 0) { | |||
| @@ -114,6 +114,45 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| } | |||
| } | |||
| Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) { | |||
| SharedLock rw(&rw_lock_); | |||
| RETURN_UNEXPECTED_IF_NULL(row_id_generated); | |||
| if (st_ == State::kFetchPhase) { | |||
| // For this kind of cache service, once we move from the build phase into the fetch phase, we can't | |||
| // allow others to cache more rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||
| } | |||
| try { | |||
| // If we don't need to generate id, we need to find it from the buffer. | |||
| if (generate_id_) { | |||
| *row_id_generated = GetNextRowId(); | |||
| // Some debug information on how many rows we have generated so far. | |||
| if ((*row_id_generated) % 1000 == 0) { | |||
| MS_LOG(DEBUG) << "Number of rows cached: " << (*row_id_generated) + 1; | |||
| } | |||
| } else { | |||
| auto msg = GetTensorRowHeaderMsg(src.GetPointer()); | |||
| if (msg->row_id() < 0) { | |||
| std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id()); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| *row_id_generated = msg->row_id(); | |||
| } | |||
| // Now we cache the flat buffer. | |||
| CachePool::key_type key; | |||
| RETURN_IF_NOT_OK(cp_->Insert({src}, &key)); | |||
| Status rc = map_->DoInsert(*row_id_generated, key); | |||
| if (rc == Status(StatusCode::kDuplicateKey)) { | |||
| MS_LOG(DEBUG) << "Ignoring duplicate key."; | |||
| } else { | |||
| RETURN_IF_NOT_OK(rc); | |||
| } | |||
| return Status::OK(); | |||
| } catch (const std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| } | |||
| } | |||
| std::ostream &operator<<(std::ostream &out, const CacheService &cs) { | |||
| // Then show any custom derived-internal stuff | |||
| out << "\nCache memory size: " << cs.cache_mem_sz_; | |||
| @@ -155,20 +194,15 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) { | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| Status CacheService::PreBatchFetch(const std::vector<row_id_type> &v, std::vector<key_size_pair> *out, | |||
| int64_t *mem_sz) { | |||
| SharedLock rw(&rw_lock_); | |||
| if (st_ == State::kBuildPhase) { | |||
| // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||
| } | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| RETURN_UNEXPECTED_IF_NULL(mem_sz); | |||
| const auto num_elements = v.size(); | |||
| int64_t mem_sz = (num_elements + 1) * sizeof(int64_t); | |||
| int64_t data_offset = mem_sz; | |||
| std::vector<int64_t> sz_v; | |||
| std::vector<CachePool::key_type> keys; | |||
| sz_v.reserve(num_elements); | |||
| keys.reserve(num_elements); | |||
| *mem_sz = (num_elements + 1) * sizeof(int64_t); | |||
| (*out).reserve(num_elements); | |||
| for (auto row_id : v) { | |||
| auto r = map_->Search(row_id); | |||
| if (r.second) { | |||
| @@ -180,25 +214,33 @@ Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint | |||
| errMsg += std::to_string(key); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| keys.push_back(key); | |||
| sz_v.push_back(sz); | |||
| mem_sz += sz; | |||
| (*out).emplace_back(key, sz); | |||
| (*mem_sz) += sz; | |||
| } else { | |||
| keys.push_back(-1); | |||
| sz_v.push_back(0); | |||
| (*out).emplace_back(-1, 0); | |||
| } | |||
| } | |||
| MemGuard<uint8_t> mem; | |||
| RETURN_IF_NOT_OK(mem.allocate(mem_sz)); | |||
| auto *offset_array = reinterpret_cast<int64_t *>(mem.GetMutablePointer()); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::BatchFetch(const std::vector<row_id_type> &v, const std::vector<key_size_pair> &info, | |||
| WritableSlice *out) const { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| SharedLock rw(&rw_lock_); | |||
| if (st_ == State::kBuildPhase) { | |||
| // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||
| } | |||
| const auto num_elements = v.size(); | |||
| int64_t data_offset = (num_elements + 1) * sizeof(int64_t); | |||
| auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer()); | |||
| offset_array[0] = data_offset; | |||
| WritableSlice all(mem.GetMutablePointer(), mem.GetSizeInBytes()); | |||
| for (auto i = 0; i < num_elements; ++i) { | |||
| auto sz = sz_v.at(i); | |||
| auto sz = info.at(i).second; | |||
| offset_array[i + 1] = offset_array[i] + sz; | |||
| if (sz > 0) { | |||
| WritableSlice row_data(all, offset_array[i], sz); | |||
| auto key = keys.at(i); | |||
| WritableSlice row_data(*out, offset_array[i], sz); | |||
| auto key = info.at(i).first; | |||
| size_t bytesRead = 0; | |||
| RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead)); | |||
| if (bytesRead != sz) { | |||
| @@ -208,7 +250,6 @@ Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint | |||
| } | |||
| } | |||
| } | |||
| *out = std::move(mem); | |||
| return Status::OK(); | |||
| } | |||
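| // Reader-side sketch (illustration only; DecodeBatch is a hypothetical helper, | |||
| // not part of this patch) of the layout BatchFetch produces above: an int64 | |||
| // offset array of num_elements + 1 entries followed by the row payloads, where | |||
| // equal consecutive offsets encode a cache miss (empty row). | |||
| static void DecodeBatch(const uint8_t *buf, size_t num_rows) { | |||
| const auto *offsets = reinterpret_cast<const int64_t *>(buf); | |||
| for (size_t i = 0; i < num_rows; ++i) { | |||
| int64_t sz = offsets[i + 1] - offsets[i]; | |||
| if (sz == 0) continue;  // cache miss was written as a zero-length row | |||
| const uint8_t *row = buf + offsets[i]; | |||
| (void)row;  // deserialize the TensorRow flatbuffer from [row, row + sz) | |||
| } | |||
| } | |||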
| Status CacheService::CacheSchema(const void *buf, int64_t len) { | |||
| @@ -232,18 +273,26 @@ Status CacheService::CacheSchema(const void *buf, int64_t len) { | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::FetchSchema(MemGuard<uint8_t> *out) const { | |||
| Status CacheService::FetchSchema(std::string *out) const { | |||
| SharedLock rw(&rw_lock_); | |||
| if (st_ == State::kBuildPhase) { | |||
| // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||
| } | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| MemGuard<uint8_t> mem; | |||
| // We are going to use std::string to allocate and hold the result which will be eventually | |||
| // 'moved' to the protobuf message (which underneath is also a std::string) in order | |||
| // to minimize memory copies. | |||
| std::string mem; | |||
| if (schema_key_ >= 0) { | |||
| auto len = cp_->GetSize(schema_key_); | |||
| RETURN_IF_NOT_OK(mem.allocate(len)); | |||
| auto slice = WritableSlice(mem.GetMutablePointer(), len); | |||
| try { | |||
| mem.resize(len); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= len, "Programming error"); | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory); | |||
| } | |||
| auto slice = WritableSlice(mem.data(), len); | |||
| RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice)); | |||
| *out = std::move(mem); | |||
| } else { | |||
| @@ -28,7 +28,6 @@ | |||
| #include "minddata/dataset/core/global_context.h" | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/engine/cache/cache_request.h" | |||
| #include "minddata/dataset/engine/cache/de_tensor_generated.h" | |||
| #include "minddata/dataset/util/arena.h" | |||
| #include "minddata/dataset/util/btree.h" | |||
| #include "minddata/dataset/util/cache_pool.h" | |||
| @@ -38,7 +37,8 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| struct CacheStat; | |||
| /// A typedef used for BatchFetch | |||
| using key_size_pair = std::pair<CachePool::key_type, size_t>; | |||
| /// \brief A cache service for storing/fetching buffers in an in-memory cache, which may spill to disk if the | |||
| /// cache service is created with spilling support. | |||
| class CacheService : public Service { | |||
| @@ -69,12 +69,26 @@ class CacheService : public Service { | |||
| /// \param[out] row_id_generated The row id assigned to this row if any | |||
| /// \return Status object | |||
| Status CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated); | |||
| /// \brief A fast version of CacheRow where all the data is already in one contiguous piece. | |||
| /// \param src Slice of the data | |||
| /// \param row_id_generated | |||
| /// \return Status object | |||
| Status FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated); | |||
| /// \brief This function is used in preparation for batch fetching. | |||
| /// It calculates how much memory we should allocate and which row ids are present. | |||
| /// \param[in] v A vector of row ids. | |||
| /// \param[out] out Pointer to a vector of <CachePool::key_type, size_t> pairs | |||
| /// \param[out] mem_sz How much memory is required to batch fetch | |||
| /// \return Status object | |||
| Status PreBatchFetch(const std::vector<row_id_type> &v, std::vector<key_size_pair> *, int64_t *mem_sz); | |||
| /// \brief Main function to fetch rows in batch. The output is a contiguous memory buffer which will be decoded | |||
| /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row. | |||
| /// \param[in] v A vector of row id. | |||
| /// \param[out] out A contiguous memory buffer that holds the requested rows. | |||
| /// \return Status object | |||
| Status BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const; | |||
| Status BatchFetch(const std::vector<row_id_type> &v, const std::vector<key_size_pair> &, WritableSlice *out) const; | |||
| /// \brief Getter function | |||
| /// \return Spilling path | |||
| @@ -102,7 +116,7 @@ class CacheService : public Service { | |||
| /// \brief Fetch schema | |||
| /// \param out A contiguous memory that contains the serialized form of schema. | |||
| /// \return Status object | |||
| Status FetchSchema(MemGuard<uint8_t> *out) const; | |||
| Status FetchSchema(std::string *out) const; | |||
| /// \brief Purge the content of a cache | |||
| /// \return Status object | |||
| Status Purge(); | |||
| @@ -60,10 +60,11 @@ table TensorRowIds { | |||
| } | |||
| /// Statistics returned from each cache service | |||
| /// \note It must match CacheService::ServiceStat | |||
| /// \note It must match CacheServiceStat | |||
| table ServiceStatMsg { | |||
| num_mem_cached:int64; | |||
| num_disk_cached:int64; | |||
| avg_cache_sz:int64; | |||
| min_row_id:int64; | |||
| max_row_id:int64; | |||
| state:int8; | |||
| @@ -79,3 +80,15 @@ table ColumnNameMsg { | |||
| table SchemaMsg { | |||
| column:[ColumnNameMsg]; | |||
| } | |||
| /// Part of the CreateCacheRequest | |||
| table CreateCacheRequestMsg { | |||
| cache_mem_sz:int64; | |||
| flag:uint32; | |||
| } | |||
| /// Return result of CreateCacheRequest | |||
| table CreateCacheReplyMsg { | |||
| connection_id:int64; | |||
| cookie:string; | |||
| } | |||
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include "proto/cache_grpc.pb.h" | |||
| #include "minddata/dataset/engine/cache/cache_common.h" | |||
| #include "minddata/dataset/engine/cache/cache_request.h" | |||
| #include "minddata/dataset/util/service.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class CacheClientGreeter : public Service { | |||
| public: | |||
| explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers) {} | |||
| ~CacheClientGreeter() override {} | |||
| Status DoServiceStart() override { RETURN_STATUS_UNEXPECTED("Not supported"); } | |||
| Status DoServiceStop() override { RETURN_STATUS_UNEXPECTED("Not supported"); } | |||
| void *SharedMemoryBaseAddr() { return nullptr; } | |||
| Status HandleRequest(std::shared_ptr<BaseRequest> rq) { RETURN_STATUS_UNEXPECTED("Not supported"); } | |||
| Status AttachToSharedMemory(int32_t port, bool *local_bypass) { RETURN_STATUS_UNEXPECTED("Not supported"); } | |||
| protected: | |||
| private: | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_ | |||
| @@ -16,6 +16,7 @@ | |||
| #include "minddata/dataset/engine/datasetops/cache_base_op.h" | |||
| #include <iomanip> | |||
| #include <iostream> | |||
| #include <utility> | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| namespace mindspore { | |||
| @@ -47,22 +48,39 @@ Status CacheBase::Reset() { | |||
| } | |||
| CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | |||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler) | |||
| : ParallelOp(num_workers, op_connector_size, sampler), | |||
| cache_client_(cache_client), | |||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | |||
| row_cnt_(0), | |||
| num_cache_miss_(0), | |||
| cache_client_(std::move(cache_client)), | |||
| rows_per_buffer_(rows_per_buf), | |||
| // We can cause deadlock if this internal Connector size is too small. | |||
| keys_miss_(num_workers_, 1, connector_capacity_) { | |||
| keys_miss_(num_workers_, 1, connector_capacity_), | |||
| prefetch_size_(cache_client_->getPrefetchSize()) { | |||
| io_block_queues_.Init(num_workers, op_connector_size); | |||
| prefetch_queues_.Init(num_workers, op_connector_size); | |||
| sampler_queue_ = std::make_unique<Queue<std::shared_ptr<Tensor>>>(op_connector_size); | |||
| } | |||
| // Common function to fetch samples from the sampler and send them using the io_block_queues to | |||
| // the parallel workers | |||
| Status CacheBase::FetchSamplesToWorkers() { | |||
| int64_t buf_cnt = 0; | |||
| int64_t wait_cnt = 0; | |||
| // Kick off several threads which will prefetch prefetch_size_ rows in advance. The rows_per_buffer_ | |||
| // value is too small (1 by default) to help performance. | |||
| RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Dispatcher", std::bind(&CacheBase::Dispatcher, this))); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1))); | |||
| // Instead of sending sampler ids to WorkerEntry, we send them to the Prefetcher which will redirect them | |||
| // to the WorkerEntry. | |||
| do { | |||
| epoch_sync_.Clear(); | |||
| if (AllowCacheMiss() && wait_cnt > 0) { | |||
| MS_LOG(WARNING) << "Epoch: " << wait_cnt << " Cache Miss : " << num_cache_miss_ | |||
| << " Total number of rows : " << row_cnt_; | |||
| } | |||
| num_cache_miss_ = 0; | |||
| row_cnt_ = 0; | |||
| ++wait_cnt; | |||
| std::vector<row_id_type> keys; | |||
| int64_t row_cnt = 0; | |||
| keys.reserve(rows_per_buffer_); | |||
| std::unique_ptr<DataBuffer> sampler_buffer; | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| @@ -70,10 +88,13 @@ Status CacheBase::FetchSamplesToWorkers() { | |||
| TensorRow sample_row; | |||
| RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); | |||
| std::shared_ptr<Tensor> sample_ids = sample_row[0]; | |||
| // Send the sampler tensor to another thread for prefetching. We are using a shared pointer so it | |||
| // won't go out of scope until it is really no longer in use. | |||
| RETURN_IF_NOT_OK(sampler_queue_->Add(sample_ids)); | |||
| for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) { | |||
| keys.push_back(*itr); | |||
| ++row_cnt; | |||
| if (row_cnt % rows_per_buffer_ == 0) { | |||
| ++row_cnt_; | |||
| if (row_cnt_ % rows_per_buffer_ == 0) { | |||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||
| RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); | |||
| keys.clear(); | |||
| @@ -90,7 +111,7 @@ Status CacheBase::FetchSamplesToWorkers() { | |||
| io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| // If repeat but the not last repeat, wait for reset. | |||
| if (!IsLastIteration()) { | |||
| MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt; | |||
| MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt; | |||
| RETURN_IF_NOT_OK(epoch_sync_.Wait()); | |||
| } else { | |||
| // We can break out from the loop. | |||
| @@ -101,13 +122,21 @@ Status CacheBase::FetchSamplesToWorkers() { | |||
| // Flow the eof before exit | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof))); | |||
| // Ask all the workers to quit. | |||
| // Shutdown threads | |||
| std::shared_ptr<Tensor> empty; | |||
| RETURN_IF_NOT_OK(sampler_queue_->Add(std::move(empty))); | |||
| for (int32_t i = 0; i < num_workers_; i++) { | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||
| } | |||
| // Dump the last epoch result (approximately) without waiting for the worker threads to come back. | |||
| if (AllowCacheMiss()) { | |||
| MS_LOG(WARNING) << "Epoch: " << wait_cnt << " Cache Miss : " << num_cache_miss_ | |||
| << " Total number of rows : " << row_cnt_; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheBase::FetchFromCache(int32_t worker_id) { | |||
| int64_t buffer_id = worker_id; | |||
| std::unique_ptr<IOBlock> blk; | |||
| @@ -133,23 +162,16 @@ Status CacheBase::FetchFromCache(int32_t worker_id) { | |||
| } | |||
| std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone); | |||
| std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>(); | |||
| TensorTable ttbl; | |||
| RETURN_IF_NOT_OK(cache_client_->GetRows(keys, &ttbl)); | |||
| auto row_it = ttbl.begin(); | |||
| std::vector<row_id_type> cache_miss; | |||
| cache_miss.reserve(keys.size()); | |||
| for (auto row_id : keys) { | |||
| auto &row = *row_it; | |||
| TensorRow row; | |||
| // Block until the row shows up in the pool. | |||
| RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, &row)); | |||
| if (row.empty()) { | |||
| if (AllowCacheMiss()) { | |||
| cache_miss.push_back(row_id); | |||
| } else { | |||
| std::string errMsg = "Row id " + std::to_string(row_id) + " not found."; | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| cache_miss.push_back(row_id); | |||
| } | |||
| que->push_back(std::move(row)); | |||
| ++row_it; | |||
| } | |||
| db->set_tensor_table(std::move(que)); | |||
| if (AllowCacheMiss()) { | |||
| @@ -162,12 +184,17 @@ Status CacheBase::FetchFromCache(int32_t worker_id) { | |||
| } while (true); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheBase::RegisterResources() { | |||
| RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(sampler_queue_->Register(tree_->AllTasks())); | |||
| return Status::OK(); | |||
| } | |||
| CacheBase::~CacheBase() {} | |||
| CacheBase::~CacheBase() = default; | |||
| Status CacheBase::UpdateColumnMapFromCache() { | |||
| Status rc; | |||
| // Get the schema from the server. It may not be there yet. So tolerate the error. | |||
| @@ -180,5 +207,77 @@ Status CacheBase::UpdateColumnMapFromCache() { | |||
| } | |||
| return rc; | |||
| } | |||
| Status CacheBase::Dispatcher() { | |||
| TaskManager::FindMe()->Post(); | |||
| int64_t buf_cnt = 0; | |||
| int64_t num_row = 0; | |||
| std::vector<row_id_type> keys; | |||
| keys.reserve(prefetch_size_); | |||
| do { | |||
| keys.clear(); | |||
| std::shared_ptr<Tensor> sample_ids; | |||
| RETURN_IF_NOT_OK(sampler_queue_->PopFront(&sample_ids)); | |||
| if (sample_ids == nullptr) { | |||
| // A null shared pointer signals that it is time to quit. | |||
| // Also signal all prefetchers to quit. | |||
| for (int32_t i = 0; i < num_workers_; i++) { | |||
| RETURN_IF_NOT_OK( | |||
| prefetch_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||
| } | |||
| break; | |||
| } | |||
| // Now we distribute the sampler ids to each prefetcher according to the prefetch size. | |||
| for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) { | |||
| keys.push_back(*itr); | |||
| ++num_row; | |||
| if (num_row % prefetch_size_ == 0) { | |||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||
| RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); | |||
| keys.clear(); | |||
| } | |||
| } | |||
| // Send the remaining sample ids | |||
| if (!keys.empty()) { | |||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||
| RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); | |||
| } | |||
| } while (true); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheBase::Prefetcher(int32_t worker_id) { | |||
| TaskManager::FindMe()->Post(); | |||
| std::vector<row_id_type> prefetch_keys; | |||
| prefetch_keys.reserve(prefetch_size_); | |||
| do { | |||
| prefetch_keys.clear(); | |||
| std::unique_ptr<IOBlock> blk; | |||
| RETURN_IF_NOT_OK(prefetch_queues_[worker_id]->PopFront(&blk)); | |||
| RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys)); | |||
| if (prefetch_keys.empty()) { | |||
| // Empty keys mean time to quit. | |||
| break; | |||
| } | |||
| TensorTable ttbl; | |||
| RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl)); | |||
| auto row_it = ttbl.begin(); | |||
| for (auto row_id : prefetch_keys) { | |||
| auto &row = *row_it; | |||
| if (row.empty()) { | |||
| if (AllowCacheMiss()) { | |||
| ++num_cache_miss_; | |||
| } else { | |||
| std::string errMsg = "Row id " + std::to_string(row_id) + " not found."; | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| } | |||
| // Put the prefetched row into the pool and wake up any WorkerEntry waiting for the row | |||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||
| ++row_it; | |||
| } | |||
| } while (true); | |||
| return Status::OK(); | |||
| } | |||
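| // Summary of the prefetch pipeline assembled in this file (inferred from the | |||
| // code above): | |||
| //   sampler --> sampler_queue_ --> Dispatcher() --> prefetch_queues_[i] | |||
| //   --> Prefetcher(i) --GetRows--> prefetch_ (QueueMap) --> FetchFromCache workers | |||
| // Shutdown flows the same way: a null tensor in sampler_queue_ makes the | |||
| // Dispatcher send empty IOBlocks, which each Prefetcher treats as a quit signal. | |||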
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -16,6 +16,8 @@ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ | |||
| #include <atomic> | |||
| #include <deque> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| @@ -28,8 +30,9 @@ | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| #include "minddata/dataset/util/queue_map.h" | |||
| #include "minddata/dataset/util/semaphore.h" | |||
| #include "minddata/dataset/util/wait_post.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_base_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities. | |||
| @@ -82,10 +85,13 @@ class CacheBase : public ParallelOp { | |||
| protected: | |||
| constexpr static int32_t eoe_row_id = -1; | |||
| int64_t row_cnt_; | |||
| std::atomic<int64_t> num_cache_miss_; | |||
| std::shared_ptr<CacheClient> cache_client_; | |||
| WaitPost epoch_sync_; | |||
| int32_t rows_per_buffer_; | |||
| Connector<std::vector<row_id_type>> keys_miss_; | |||
| QueueMap<row_id_type, TensorRow> prefetch_; | |||
| /// \brief Common function to register resources for interrupt | |||
| /// \note Derived should override this function for extra resources to be registered | |||
| @@ -103,7 +109,15 @@ class CacheBase : public ParallelOp { | |||
| private: | |||
| constexpr static int32_t connector_capacity_ = 1024; | |||
| int32_t prefetch_size_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| QueueList<std::unique_ptr<IOBlock>> prefetch_queues_; | |||
| std::unique_ptr<Queue<std::shared_ptr<Tensor>>> sampler_queue_; | |||
| Status Dispatcher(); | |||
| /// \brief Prefetcher. It prefetches rows from the cache server | |||
| /// \return Status object. | |||
| Status Prefetcher(int32_t worker_id); | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -16,8 +16,10 @@ | |||
| #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | |||
| #include <algorithm> | |||
| #include <chrono> | |||
| #include <functional> | |||
| #include <iomanip> | |||
| #include <utility> | |||
| #include "minddata/dataset/core/config_manager.h" | |||
| #include "minddata/dataset/core/constants.h" | |||
| #include "minddata/dataset/core/global_context.h" | |||
| @@ -41,9 +43,13 @@ void CacheMergeOp::Print(std::ostream &out, bool show_all) const { | |||
| out << "\n\n"; | |||
| } | |||
| } | |||
| CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, | |||
| std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler) | |||
| : ParallelOp(numWorkers, opConnectorSize, sampler), num_cleaners_(numCleaners), cache_client_(cache_client) {} | |||
| : ParallelOp(numWorkers, opConnectorSize, sampler), | |||
| num_cleaners_(numCleaners), | |||
| cache_client_(std::move(cache_client)) {} | |||
| Status CacheMergeOp::operator()() { | |||
| // A queue of row id to let cleaner send cache miss rows to the cache server | |||
| // We don't want a small queue as this will block the parallel op workers. | |||
| @@ -62,6 +68,7 @@ Status CacheMergeOp::operator()() { | |||
| TaskManager::FindMe()->Post(); | |||
| return Status::OK(); | |||
| } | |||
| // Each parallel worker will pop from the CacheHit stream. If there is a missing TensorRow, we will wait | |||
| // until it shows up in the pool. | |||
| Status CacheMergeOp::WorkerEntry(int32_t worker_id) { | |||
| @@ -82,10 +89,8 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) { | |||
| RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); | |||
| if (row.empty()) { | |||
| auto row_id = row.getId(); | |||
| TensorRowRequest *rq = nullptr; | |||
| RETURN_IF_NOT_OK(GetRq(row_id, &rq)); | |||
| // Block until the row shows up in the pool. | |||
| RETURN_IF_NOT_OK(rq->Wait(&row)); | |||
| RETURN_IF_NOT_OK(cache_miss_.PopFront(row_id, &row)); | |||
| } | |||
| tbl->push_back(std::move(row)); | |||
| } | |||
| @@ -97,6 +102,7 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) { | |||
| RETURN_IF_NOT_OK(EofReceived(worker_id)); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { | |||
| TaskManager::FindMe()->Post(); | |||
| // We will simply pop TensorRow from the stream and insert them into the pool and | |||
| @@ -123,17 +129,27 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { | |||
| std::string errMsg = "Expect positive row id: " + std::to_string(row_id); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| TensorRowRequest *rq = nullptr; | |||
| // Technically, the number of times this row shows up in the cache miss stream is equal to the number | |||
| // of P() calls. However the cleaner wants it too, so we need an extra copy. | |||
| TensorRowCacheRequest *rq; | |||
| RETURN_IF_NOT_OK(GetRq(row_id, &rq)); | |||
| rq->WakeUpAny(std::move(row)); | |||
| // Let the cleaner to flush out this row (async) to the cache server. | |||
| RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); | |||
| if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) { | |||
| // We will send the request async. Any error is most | |||
| // likely ignored, and we continue. | |||
| Status rc; | |||
| rc = rq->AsyncSendCacheRequest(cache_client_, row); | |||
| if (rc.IsOk()) { | |||
| RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(cache_miss_.Add(row_id, std::move(row))); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheMergeOp::Cleaner() { | |||
| TaskManager::FindMe()->Post(); | |||
| while (true) { | |||
| @@ -142,45 +158,28 @@ Status CacheMergeOp::Cleaner() { | |||
| if (row_id < 0) { | |||
| break; | |||
| } | |||
| TensorRowRequest *rq = nullptr; | |||
| // Locate the cache request | |||
| TensorRowCacheRequest *rq; | |||
| RETURN_IF_NOT_OK(GetRq(row_id, &rq)); | |||
| if (rq->GetState() == TensorRowRequest::State::kClean) { | |||
| // If already flushed, move on to the next one. | |||
| // If already flushed, move on to the next one. | |||
| if (rq->GetState() == TensorRowCacheRequest::State::kClean) { | |||
| continue; | |||
| } | |||
| TensorRow row; | |||
| RETURN_IF_NOT_OK(rq->Release(&row)); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!row.empty(), "Programming error."); | |||
| Status rc = cache_client_->WriteRow(row); | |||
| // Bad rc should not bring down the pipeline | |||
| Status rc = rq->CheckCacheResult(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(WARNING) << "Cache not successful." << rc.ToString(); | |||
| // If interrupted, it is time to quit. | |||
| if (rc.get_code() == StatusCode::kInterrupted) { | |||
| return Status::OK(); | |||
| } | |||
| MS_LOG(INFO) << "Cache row not successful: " << rc.ToString(); | |||
| // Bad rc should not bring down the pipeline. We will simply continue and | |||
| // change the state back to empty. We don't need a CAS from CLEAN back to EMPTY. | |||
| rq->SetState(TensorRowCacheRequest::State::kEmpty); | |||
| } | |||
| rq->SetState(TensorRowRequest::State::kClean); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowRequest **out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| std::unique_lock<std::mutex> lck(mux_); | |||
| auto it = cache_miss_map_.find(row_id); | |||
| if (it != cache_miss_map_.end()) { | |||
| *out = it->second.GetMutablePointer(); | |||
| } else { | |||
| // We will create a new one. | |||
| auto alloc = Services::GetAllocator<TensorRowRequest>(); | |||
| auto r = cache_miss_map_.emplace(row_id, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>(alloc)); | |||
| if (r.second) { | |||
| auto &mem = r.first->second; | |||
| RETURN_IF_NOT_OK(mem.allocate(1, row_id)); | |||
| *out = mem.GetMutablePointer(); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Map insert fail."); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from super class first before adding our own | |||
| // specific logic | |||
| CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children"); | |||
| @@ -199,6 +198,7 @@ Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from supe | |||
| RETURN_IF_NOT_OK(rc); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheMergeOp::ComputeColMap() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(child_[kCacheMissChildIdx] != nullptr, "Cache miss stream empty"); | |||
| if (column_name_id_map().empty()) { | |||
| @@ -207,53 +207,13 @@ Status CacheMergeOp::ComputeColMap() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map().empty(), "No column map detected"); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheMergeOp::TensorRowRequest::Wait(TensorRow *out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| // Block until the missing row is in the pool. | |||
| RETURN_IF_NOT_OK(use_count_.P()); | |||
| std::unique_lock<std::mutex> lck(dq_mux_); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error"); | |||
| *out = std::move(row_.front()); | |||
| row_.pop_front(); | |||
| return Status::OK(); | |||
| } | |||
| void CacheMergeOp::TensorRowRequest::WakeUpAny(TensorRow &&row) { | |||
| std::unique_lock<std::mutex> lck(dq_mux_); | |||
| // Technically number of this row shows up in the cache miss stream is equal to the number | |||
| // of P() call. However the cleaner wants it too. So we need an extra copy. | |||
| if (GetState() == State::kEmpty) { | |||
| // We will do a deep copy | |||
| for (auto &ts : row) { | |||
| std::shared_ptr<Tensor> out_ts; | |||
| Tensor::CreateFromTensor(ts, &out_ts); | |||
| cleaner_copy_.push_back(out_ts); | |||
| } | |||
| cleaner_copy_.setId(row.getId()); | |||
| // Change the state to dirty | |||
| SetState(State::kDirty); | |||
| } | |||
| row_.push_back(std::move(row)); | |||
| // Bump up the use count by 1. This wake up any parallel worker which is waiting | |||
| // for this row. | |||
| use_count_.V(); | |||
| } | |||
| Status CacheMergeOp::TensorRowRequest::Release(TensorRow *out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| // We are not holding any mutex here because the cleaner isn't really touching the deque row_. | |||
| // In case we have multiple cleaners and they all see the copy, only one of them will | |||
| // get it. | |||
| auto expected = State::kDirty; | |||
| if (st_.compare_exchange_strong(expected, State::kClean)) { | |||
| *out = std::move(cleaner_copy_); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Builder constructor. Creates the builder object. | |||
| CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| build_num_workers_ = cfg->num_parallel_workers(); | |||
| build_op_connector_size_ = cfg->op_connector_size(); | |||
| build_num_cleaners_ = 1; | |||
| build_num_cleaners_ = cfg->num_parallel_workers(); | |||
| } | |||
| // Check if the required parameters are set by the builder. | |||
| @@ -311,5 +271,60 @@ Status CacheMergeOp::EofReceived(int32_t worker_id) { | |||
| MS_LOG(DEBUG) << "Cache merge sending eof"; | |||
| return DatasetOp::EofReceived(worker_id); | |||
| } | |||
| Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowCacheRequest **out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| std::unique_lock<std::mutex> lock(mux_); | |||
| auto it = io_request_.find(row_id); | |||
| if (it != io_request_.end()) { | |||
| *out = it->second.GetMutablePointer(); | |||
| } else { | |||
| // We will create a new one. | |||
| auto alloc = Services::GetAllocator<TensorRowCacheRequest>(); | |||
| auto r = io_request_.emplace(row_id, MemGuard<TensorRowCacheRequest, Allocator<TensorRowCacheRequest>>(alloc)); | |||
| if (r.second) { | |||
| auto &mem = r.first->second; | |||
| RETURN_IF_NOT_OK(mem.allocate(1)); | |||
| *out = mem.GetMutablePointer(); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Map insert fail."); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::shared_ptr<CacheClient> &cc, | |||
| const TensorRow &row) { | |||
| auto expected = State::kEmpty; | |||
| if (st_.compare_exchange_strong(expected, State::kDirty)) { | |||
| // We will do a deep copy but write directly into CacheRequest protobuf or shared memory | |||
| Status rc; | |||
| cleaner_copy_ = | |||
| std::make_shared<CacheRowRequest>(cc->server_connection_id_, cc->cookie(), cc->SupportLocalClient()); | |||
| rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row); | |||
| if (rc.IsOk()) { | |||
| // Send the request async. The cleaner will check the return code. | |||
| rc = cc->PushRequest(cleaner_copy_); | |||
| } | |||
| if (rc.IsError()) { | |||
| // Clean up the shared pointer and reset the state back to empty | |||
| cleaner_copy_.reset(); | |||
| st_ = State::kEmpty; | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheMergeOp::TensorRowCacheRequest::CheckCacheResult() { | |||
| auto expected = State::kDirty; | |||
| if (st_.compare_exchange_strong(expected, State::kClean)) { | |||
| // Success or not, we will release the memory. | |||
| // We simply move it out of the structure and let it go out of scope. | |||
| auto cache_request = std::move(cleaner_copy_); | |||
| RETURN_IF_NOT_OK(cache_request->Wait()); | |||
| return Status::OK(); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
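| // Standalone illustration (simplified, hypothetical names; assumes <atomic> is | |||
| // available via the header) of the ownership pattern AsyncSendCacheRequest and | |||
| // CheckCacheResult use above: each state transition is claimed with | |||
| // compare_exchange_strong, so when several threads race on the same row exactly | |||
| // one wins the transition and the losers fall through as no-ops. | |||
| namespace { | |||
| enum class DemoState : uint8_t { kEmpty, kDirty, kClean }; | |||
| bool TryClaim(std::atomic<DemoState> *st, DemoState from, DemoState to) { | |||
| auto expected = from; | |||
| return st->compare_exchange_strong(expected, to);  // false if another thread won | |||
| } | |||
| }  // namespace | |||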
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -16,6 +16,7 @@ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ | |||
| #include <algorithm> | |||
| #include <atomic> | |||
| #include <deque> | |||
| #include <map> | |||
| @@ -28,6 +29,7 @@ | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/dataset_iterator.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| #include "minddata/dataset/util/queue_map.h" | |||
| #include "minddata/dataset/util/semaphore.h" | |||
| namespace mindspore { | |||
| @@ -36,28 +38,34 @@ namespace dataset { | |||
| /// stream | |||
| class CacheMergeOp : public ParallelOp { | |||
| public: | |||
| // Some handshake structures among the main thread, cleaner threads and parallel op threads. | |||
| class TensorRowRequest { | |||
| // A handshake structure between CacheMissWorkerEntry and the Cleaner | |||
| class TensorRowCacheRequest { | |||
| public: | |||
| enum class State : uint8_t { | |||
| kEmpty = 0, // No row in the deque | |||
| kEmpty = 0, // Initial state. Row hasn't arrived from cache miss stream yet. | |||
| kDirty = 1, // Cleaner hasn't flushed it to the cache server yet. | |||
| kClean = 2 // The row has been flushed already. | |||
| }; | |||
| explicit TensorRowRequest(row_id_type id) : st_(State::kEmpty), use_count_(0) {} | |||
| ~TensorRowRequest() = default; | |||
| TensorRowCacheRequest() : st_(State::kEmpty) {} | |||
| ~TensorRowCacheRequest() = default; | |||
| /// Getter and Setter of the state | |||
| State GetState() const { return st_; } | |||
| void SetState(State newState) { st_ = newState; } | |||
| Status Wait(TensorRow *out); | |||
| void WakeUpAny(TensorRow &&row); | |||
| Status Release(TensorRow *out); | |||
| /// Take a tensor row and send an rpc call to the server asynchronously | |||
| /// \param cc Cache client of the CacheMergeOp | |||
| /// \param row TensorRow to be sent to the server | |||
| /// \return Status object | |||
| /// \note Thread safe | |||
| Status AsyncSendCacheRequest(const std::shared_ptr<CacheClient> &cc, const TensorRow &row); | |||
| /// \brief We send the row to the server async so the CacheMissWorkerEntry can continue. | |||
| /// It is the cleaner that will check the result. | |||
| /// \return Status object | |||
| Status CheckCacheResult(); | |||
| private: | |||
| std::mutex dq_mux_; | |||
| std::atomic<State> st_; | |||
| Semaphore use_count_; | |||
| std::deque<TensorRow> row_; | |||
| TensorRow cleaner_copy_; | |||
| std::shared_ptr<CacheRowRequest> cleaner_copy_; | |||
| }; | |||
| constexpr static int kCacheHitChildIdx = 0; // Cache hit stream | |||
| @@ -80,6 +88,8 @@ class CacheMergeOp : public ParallelOp { | |||
| /// \return Builder setter method returns reference to the builder. | |||
| Builder &SetNumWorkers(int32_t num_workers) { | |||
| build_num_workers_ = num_workers; | |||
| // Adjust the number of cleaners to match the number of workers | |||
| build_num_cleaners_ = std::max(build_num_cleaners_, build_num_workers_); | |||
| return *this; | |||
| } | |||
| @@ -159,7 +169,6 @@ class CacheMergeOp : public ParallelOp { | |||
| /// \param workerId | |||
| /// \return Status object | |||
| Status CacheMissWorkerEntry(int32_t workerId); | |||
| Status GetRq(row_id_type row_id, TensorRowRequest **); | |||
| /// \brief Base-class override for NodePass pre-visit acceptor | |||
| /// \param[in] p The node to visit | |||
| @@ -188,11 +197,18 @@ class CacheMergeOp : public ParallelOp { | |||
| private: | |||
| std::mutex mux_; | |||
| std::map<row_id_type, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>> cache_miss_map_; | |||
| QueueMap<row_id_type, TensorRow> cache_miss_; | |||
| std::map<row_id_type, MemGuard<TensorRowCacheRequest, Allocator<TensorRowCacheRequest>>> io_request_; | |||
| std::unique_ptr<Queue<row_id_type>> io_que_; | |||
| std::shared_ptr<CacheClient> cache_client_; | |||
| int32_t num_cleaners_; | |||
| /// \brief Locate the cache request from the io_request_ map | |||
| /// \param row_id | |||
| /// \param out pointer to the cache request | |||
| /// \return Status object | |||
| Status GetRq(row_id_type row_id, TensorRowCacheRequest **out); | |||
| /// \brief These are the entry functions for the cleaner threads. Each cleaner is responsible for | |||
| /// moving cache miss TensorRows into the CacheServer. | |||
| /// \return Status object | |||
| @@ -142,7 +142,7 @@ Status CacheOp::WaitForCachingAllRows() { | |||
| } | |||
| // Get statistics from the server, and if we are not the one to create the cache, | |||
| // wait until the state changes from build phase to fetch phase. | |||
| CacheClient::ServiceStat stat{}; | |||
| CacheServiceStat stat{}; | |||
| bool BuildPhaseDone = true; | |||
| do { | |||
| RETURN_IF_NOT_OK(cache_client_->GetStat(&stat)); | |||
| @@ -157,6 +157,7 @@ Status CacheOp::WaitForCachingAllRows() { | |||
| MS_LOG(INFO) << "Number of rows cached: " << num_rows_; | |||
| MS_LOG(INFO) << "Number of rows cached in memory : " << stat.num_mem_cached; | |||
| MS_LOG(INFO) << "Number of rows spilled to disk : " << stat.num_disk_cached; | |||
| MS_LOG(INFO) << "Average cache size : " << stat.avg_cache_sz; | |||
| // Now all rows are cached and we have done a sync point check up. Next phase is | |||
| // is pick up fetch input from sampler and pass up to the caller. | |||
| RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); | |||
| @@ -392,6 +392,13 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) { | |||
| ss_str = std::regex_replace(ss_str, std::regex("Num workers.*\n"), ""); | |||
| ss_str = std::regex_replace(ss_str, std::regex("\\[workers.*\\]"), ""); | |||
| // Filter out tcp/ip information | |||
| ss_str = std::regex_replace(ss_str, std::regex("Hostname.*\n"), ""); | |||
| ss_str = std::regex_replace(ss_str, std::regex("Port.*\n"), ""); | |||
| ss_str = std::regex_replace(ss_str, std::regex("Number of rpc workers.*\n"), ""); | |||
| ss_str = std::regex_replace(ss_str, std::regex("Prefetch size.*\n"), ""); | |||
| ss_str = std::regex_replace(ss_str, std::regex("Local client support.*\n"), ""); | |||
| // Filter out Number of rows when generating the check sum | |||
| ss_str = std::regex_replace(ss_str, std::regex("Number of rows.*\n"), ""); | |||
| @@ -73,6 +73,7 @@ enum class StatusCode : char { | |||
| kProfilingError = 10, | |||
| kBoundingBoxOutOfBounds = 11, | |||
| kBoundingBoxInvalidShape = 12, | |||
| kSyntaxError = 13, | |||
| // Make this error code the last one. Add new error code above it. | |||
| kUnexpectedError = 127 | |||
| }; | |||
| @@ -168,9 +168,9 @@ class MemGuard { | |||
| size_t GetSizeInBytes() const { return n_ * sizeof(T); } | |||
| private: | |||
| size_t n_; | |||
| allocator alloc_; | |||
| std::unique_ptr<T[]> ptr_; | |||
| size_t n_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -27,20 +27,20 @@ | |||
| #define ARENA_WALL_OVERHEAD_SZ 32 | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // This is a memory arena based on a treap data structure. | |||
| // The constructor of the Arena takes the size of the initial memory size (in MB). | |||
| // Internally we divide the memory into multiple blocks. Each block is 64 bytes. | |||
| // The treap contains all the free blocks with the relative memory address as key | |||
| // and the size of the block as priority. | |||
| // | |||
| // Initially the treap has only one root which is the whole memory piece. | |||
| // | |||
| // For memory suballocation, we pop the root node of the treap which contains the largest free block. | |||
| // We allocate what we need and return the rest back to the treap. We search for the first fit instead | |||
| // of the best fit so to give us a constant time in memory allocation. | |||
| // | |||
| // When a block of memory is freed. It is joined with the blocks before and after (if they are available) to | |||
| // form a bigger block. | |||
| /// This is a memory arena based on a treap data structure. | |||
| /// The constructor of the Arena takes the size of the initial memory size (in MB). | |||
| /// Internally we divide the memory into multiple blocks. Each block is 64 bytes. | |||
| /// The treap contains all the free blocks with the relative memory address as key | |||
| /// and the size of the block as priority. | |||
| /// | |||
| /// Initially the treap has only one root which is the whole memory piece. | |||
| /// | |||
| /// For memory suballocation, we pop the root node of the treap which contains the largest free block. | |||
| /// We allocate what we need and return the rest back to the treap. We search for the first fit instead | |||
| /// of the best fit so as to give us constant-time memory allocation. | |||
| /// | |||
| /// When a block of memory is freed, it is joined with the blocks before and after it (if they are free) to | |||
| /// form a bigger block. | |||
| class Arena : public MemoryPool { | |||
| public: | |||
| Arena(const Arena &) = delete; | |||
| @@ -78,7 +78,7 @@ class Arena : public MemoryPool { | |||
| static Status CreateArena(std::shared_ptr<Arena> *p_ba, size_t val_in_MB = 4096); | |||
| private: | |||
| protected: | |||
| std::mutex mux_; | |||
| Treap<uint64_t, uint64_t> tr_; | |||
| void *ptr_; | |||
| @@ -140,13 +140,22 @@ Path CachePool::GetSpillPath() const { | |||
| } | |||
| CachePool::CacheStat CachePool::GetStat() const { | |||
| CacheStat cs{0}; | |||
| int64_t total_sz = 0; | |||
| for (auto &it : *tree_) { | |||
| total_sz += it.sz; | |||
| if (it.ptr != nullptr) { | |||
| ++cs.num_mem_cached; | |||
| } else { | |||
| ++cs.num_disk_cached; | |||
| } | |||
| } | |||
| if (total_sz > 0) { | |||
| // Integer arithmetic; no need to cast to float or double. | |||
| cs.average_cache_sz = total_sz / (cs.num_disk_cached + cs.num_mem_cached); | |||
| if (cs.average_cache_sz == 0) { | |||
| cs.average_cache_sz = 1; | |||
| } | |||
| } | |||
| return cs; | |||
| } | |||
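| // Worked example of the integer average above (illustration): 4 cached rows | |||
| // totalling 10 bytes give 10 / 4 == 2; 4 rows totalling 3 bytes give 3 / 4 == 0, | |||
| // which is clamped up to 1 so a non-empty cache never reports an average of 0. | |||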
| Status CachePool::Spill(CachePool::DataLocator *dl) { | |||
| @@ -82,6 +82,7 @@ class CachePool : public Service { | |||
| struct CacheStat { | |||
| int64_t num_mem_cached; | |||
| int64_t num_disk_cached; | |||
| int64_t average_cache_sz; | |||
| }; | |||
| /// \brief Constructor | |||
| @@ -0,0 +1,127 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_ | |||
| #include <deque> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include "minddata/dataset/util/allocator.h" | |||
| #include "minddata/dataset/util/semaphore.h" | |||
| #include "minddata/dataset/util/services.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| template <typename K, typename T> | |||
| /// \brief QueueMap is like a Queue, but instead of a single queue it maintains a map of deque<T>. | |||
| /// Consumer will block if the corresponding deque is empty. | |||
| /// Producer can add an element of type T with key of type K to the map and | |||
| /// wake up any waiting consumer. | |||
| /// \tparam K key type | |||
| /// \tparam T payload of the map | |||
| class QueueMap { | |||
| public: | |||
| using key_type = K; | |||
| using value_type = T; | |||
| QueueMap() = default; | |||
| virtual ~QueueMap() = default; | |||
| /// Add an element <key, T> to the map and wake up any consumer that is waiting | |||
| /// \param key | |||
| /// \param payload | |||
| /// \return Status object | |||
| virtual Status Add(key_type key, T &&payload) { | |||
| RequestQueue *rq = nullptr; | |||
| RETURN_IF_NOT_OK(GetRq(key, &rq)); | |||
| RETURN_IF_NOT_OK(rq->WakeUpAny(std::move(payload))); | |||
| return Status::OK(); | |||
| } | |||
| /// Pop the front of the deque with key. Block if the deque is empty. | |||
| virtual Status PopFront(key_type key, T *out) { | |||
| RequestQueue *rq = nullptr; | |||
| RETURN_IF_NOT_OK(GetRq(key, &rq)); | |||
| RETURN_IF_NOT_OK(rq->Wait(out)); | |||
| return Status::OK(); | |||
| } | |||
| protected: | |||
| /// This is a handshake structure between producer and consumer | |||
| class RequestQueue { | |||
| public: | |||
| RequestQueue() : use_count_(0) {} | |||
| ~RequestQueue() = default; | |||
| Status Wait(T *out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| // Block until the missing row is in the pool. | |||
| RETURN_IF_NOT_OK(use_count_.P()); | |||
| std::unique_lock<std::mutex> lck(dq_mux_); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error"); | |||
| *out = std::move(row_.front()); | |||
| row_.pop_front(); | |||
| return Status::OK(); | |||
| } | |||
| Status WakeUpAny(T &&row) { | |||
| std::unique_lock<std::mutex> lck(dq_mux_); | |||
| row_.push_back(std::move(row)); | |||
| // Bump up the use count by 1. This wakes up any parallel worker that is waiting | |||
| // for this row. | |||
| use_count_.V(); | |||
| return Status::OK(); | |||
| } | |||
| private: | |||
| std::mutex dq_mux_; | |||
| Semaphore use_count_; | |||
| std::deque<T> row_; | |||
| }; | |||
| /// Create or locate an element with matching key | |||
| /// \param key | |||
| /// \param out | |||
| /// \return Status object | |||
| Status GetRq(key_type key, RequestQueue **out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| std::unique_lock<std::mutex> lck(mux_); | |||
| auto it = all_.find(key); | |||
| if (it != all_.end()) { | |||
| *out = it->second.GetMutablePointer(); | |||
| } else { | |||
| // We will create a new one. | |||
| auto alloc = Services::GetAllocator<RequestQueue>(); | |||
| auto r = all_.emplace(key, MemGuard<RequestQueue, Allocator<RequestQueue>>(alloc)); | |||
| if (r.second) { | |||
| auto &mem = r.first->second; | |||
| RETURN_IF_NOT_OK(mem.allocate(1)); | |||
| *out = mem.GetMutablePointer(); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Map insert failed."); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| private: | |||
| std::mutex mux_; | |||
| std::map<K, MemGuard<RequestQueue, Allocator<RequestQueue>>> all_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_ | |||
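The header above is the complete handshake: `Add` creates the per-key `RequestQueue` on demand, pushes the payload, and posts the semaphore, while `PopFront` blocks on that semaphore until a matching row arrives. A minimal usage sketch, not part of the patch (the `Demo` wrapper, key, and payload are illustrative):

```cpp
#include <string>
#include <thread>
#include "minddata/dataset/util/queue_map.h"

namespace mindspore {
namespace dataset {
Status Demo() {
  QueueMap<int64_t, std::string> qm;
  // Consumer: blocks inside PopFront() until key 42 has a payload.
  std::thread consumer([&qm]() {
    std::string row;
    (void)qm.PopFront(42, &row);  // waits on the per-key semaphore
  });
  // Producer: creates the RequestQueue for key 42 on demand and wakes the consumer.
  RETURN_IF_NOT_OK(qm.Add(42, "row-42-payload"));
  consumer.join();
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
```

Note the order of the two sides does not matter: if the consumer runs first it parks on the semaphore; if the producer runs first the payload simply sits in the deque until claimed.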
| @@ -22,7 +22,6 @@ | |||
| #include <stdlib.h> | |||
| #endif | |||
| #include <unistd.h> | |||
| #include "minddata/dataset/engine/cache/cache_server.h" | |||
| #include "minddata/dataset/util/circular_pool.h" | |||
| #include "minddata/dataset/util/random.h" | |||
| #include "minddata/dataset/util/task_manager.h" | |||
| @@ -59,35 +58,15 @@ std::string Services::GetUniqueID() { | |||
| return std::string(buffer, UNIQUEID_LEN); | |||
| } | |||
| TaskManager &Services::getTaskMgrInstance() { | |||
| Services &sm = GetInstance(); | |||
| return *(static_cast<TaskManager *>(sm.sa_[kSlotTaskMgr_])); | |||
| } | |||
| CacheServer &Services::getCacheServer() { | |||
| Services &sm = GetInstance(); | |||
| return *(static_cast<CacheServer *>(sm.sa_[kSlotCacheMgr_])); | |||
| } | |||
| Status Services::CreateAllInstances() { | |||
| // In order, TaskMgr, BufferMgr | |||
| Status rc; | |||
| sa_[kSlotTaskMgr_] = new (&rc, pool_) TaskManager(); | |||
| RETURN_IF_NOT_OK(rc); | |||
| rc = sa_[kSlotTaskMgr_]->ServiceStart(); | |||
| RETURN_IF_NOT_OK(rc); | |||
| // TODO(jesse) : Get the parameters from config file. Right now spill to /tmp and spawn 3 workers | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| sa_[kSlotCacheMgr_] = new (&rc, pool_) CacheServer("/tmp", 3); | |||
| RETURN_IF_NOT_OK(rc); | |||
| rc = sa_[kSlotCacheMgr_]->ServiceStart(); | |||
| #else | |||
| sa_[kSlotCacheMgr_] = nullptr; | |||
| #endif | |||
| return rc; | |||
| // First one is always the TaskManager | |||
| RETURN_IF_NOT_OK(TaskManager::CreateInstance()); | |||
| TaskManager &tm = TaskManager::GetInstance(); | |||
| RETURN_IF_NOT_OK(tm.ServiceStart()); | |||
| return Status::OK(); | |||
| } | |||
| Services::Services() : pool_(nullptr), sa_{nullptr} { | |||
| Services::Services() : pool_(nullptr) { | |||
| Status rc = CircularPool::CreateCircularPool(&pool_, -1, 16, true); // each arena 16M | |||
| if (rc.IsError()) { | |||
| std::terminate(); | |||
| @@ -95,22 +74,11 @@ Services::Services() : pool_(nullptr), sa_{nullptr} { | |||
| } | |||
| Services::~Services() noexcept { | |||
| try { | |||
| // In reverse order | |||
| CacheServer *cs = static_cast<CacheServer *>(sa_[kSlotCacheMgr_]); | |||
| if (cs != nullptr) { | |||
| (void)cs->ServiceStop(); | |||
| cs->~CacheServer(); | |||
| pool_->Deallocate(cs); | |||
| } | |||
| TaskManager *tm = static_cast<TaskManager *>(sa_[kSlotTaskMgr_]); | |||
| if (tm != nullptr) { | |||
| (void)tm->ServiceStop(); | |||
| tm->~TaskManager(); | |||
| pool_->Deallocate(tm); | |||
| } | |||
| } catch (const std::exception &e) { | |||
| // Do nothing. | |||
| // Shutdown in reverse order. | |||
| while (!hook_.empty()) { | |||
| hook_.pop_back(); | |||
| } | |||
| } | |||
| } // namespace dataset | |||
| @@ -16,9 +16,11 @@ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SERVICES_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SERVICES_H_ | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/util/memory_pool.h" | |||
| #include "minddata/dataset/util/allocator.h" | |||
| #include "minddata/dataset/util/service.h" | |||
| @@ -27,7 +29,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class TaskManager; | |||
| class CacheServer; | |||
| class Services { | |||
| public: | |||
| static Status CreateInstance() { | |||
| @@ -59,10 +61,6 @@ class Services { | |||
| ~Services() noexcept; | |||
| static TaskManager &getTaskMgrInstance(); | |||
| static CacheServer &getCacheServer(); | |||
| std::shared_ptr<MemoryPool> GetServiceMemPool() { return pool_; } | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| @@ -80,19 +78,29 @@ class Services { | |||
| return Allocator<T>(Services::GetInstance().GetServiceMemPool()); | |||
| } | |||
| /// \brief Add a new service to the start-up list. | |||
| /// \tparam T Class that implements Service | |||
| /// \param out On success, points to the newly created service; ownership stays with the hook_ list | |||
| /// \return Status object | |||
| template <typename T, typename... Args> | |||
| Status AddHook(T **out, Args &&... args) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| try { | |||
| (*out) = new T(std::forward<Args>(args)...); | |||
| std::unique_ptr<T> svc(*out); | |||
| hook_.push_back(std::move(svc)); | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| private: | |||
| static std::once_flag init_instance_flag_; | |||
| static std::unique_ptr<Services> instance_; | |||
| // A small pool used for small objects that last until the | |||
| // Services Manager shuts down. Used by all sub-services. | |||
| std::shared_ptr<MemoryPool> pool_; | |||
| // We use pointers here instead of unique_ptr because we | |||
| // want to have ultimate control on the order of | |||
| // construction and destruction. | |||
| static constexpr int kSlotTaskMgr_ = 0; | |||
| static constexpr int kSlotCacheMgr_ = 1; | |||
| static constexpr int kNumServices_ = 2; | |||
| Service *sa_[kNumServices_]; | |||
| std::vector<std::unique_ptr<Service>> hook_; | |||
| Services(); | |||
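For reference, this is the shape a service is expected to take when it registers itself through `AddHook`; the `TaskManager::CreateInstance` hunk below follows the same pattern. A hedged sketch only: `MyService` and its empty start/stop bodies are illustrative, and it assumes `DoServiceStart`/`DoServiceStop` are the `Service` overrides.

```cpp
#include <mutex>
#include "minddata/dataset/util/services.h"

namespace mindspore {
namespace dataset {
class MyService : public Service {
 public:
  static Status CreateInstance() {
    std::call_once(init_flag_, []() {
      // Services takes ownership; entries in hook_ are destroyed in reverse
      // order of registration when Services shuts down. As with TaskManager,
      // an error Status from AddHook is discarded inside call_once.
      (void)Services::GetInstance().AddHook(&instance_);
    });
    return Status::OK();
  }
  static MyService &GetInstance() noexcept { return *instance_; }
  Status DoServiceStart() override { return Status::OK(); }
  Status DoServiceStop() override { return Status::OK(); }

 private:
  static std::once_flag init_flag_;
  static MyService *instance_;
};

std::once_flag MyService::init_flag_;
MyService *MyService::instance_ = nullptr;
}  // namespace dataset
}  // namespace mindspore
```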
| @@ -86,6 +86,7 @@ class ReadableSlice { | |||
| class WritableSlice : public ReadableSlice { | |||
| public: | |||
| friend class StorageContainer; | |||
| friend class CacheService; | |||
| /// \brief Default constructor | |||
| WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {} | |||
| /// \brief This form of a constructor takes a pointer and its size. | |||
| @@ -48,6 +48,9 @@ std::string CodeAsString(const StatusCode c) { | |||
| case StatusCode::kProfilingError: | |||
| s = "Error encountered while profiling"; | |||
| break; | |||
| case StatusCode::kSyntaxError: | |||
| s = "Syntax error"; | |||
| break; | |||
| case StatusCode::kUnexpectedError: | |||
| default: | |||
| s = "Unexpected error"; | |||
| @@ -80,6 +80,7 @@ enum class StatusCode : char { | |||
| kProfilingError = 10, | |||
| kBoundingBoxOutOfBounds = 11, | |||
| kBoundingBoxInvalidShape = 12, | |||
| kSyntaxError = 13, | |||
| // Make this error code the last one. Add new error code above it. | |||
| kUnexpectedError = 127 | |||
| }; | |||
| @@ -21,6 +21,8 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| TaskManager *TaskManager::instance_ = nullptr; | |||
| std::once_flag TaskManager::init_instance_flag_; | |||
| // This takes the same parameter as Task constructor. | |||
| Status TaskManager::CreateAsyncTask(const std::string &my_name, const std::function<Status()> &f, TaskGroup *vg, | |||
| Task **task) { | |||
| @@ -54,7 +54,16 @@ class TaskManager : public Service { | |||
| TaskManager &operator=(const TaskManager &) = delete; | |||
| static TaskManager &GetInstance() noexcept { return Services::getTaskMgrInstance(); } | |||
| static Status CreateInstance() { | |||
| std::call_once(init_instance_flag_, [&]() -> Status { | |||
| auto &svcManager = Services::GetInstance(); | |||
| RETURN_IF_NOT_OK(svcManager.AddHook(&instance_)); | |||
| return Status::OK(); | |||
| }); | |||
| return Status::OK(); | |||
| } | |||
| static TaskManager &GetInstance() noexcept { return *instance_; } | |||
| Status DoServiceStart() override; | |||
| @@ -96,6 +105,8 @@ class TaskManager : public Service { | |||
| Status WatchDog(); | |||
| private: | |||
| static std::once_flag init_instance_flag_; | |||
| static TaskManager *instance_; | |||
| RWLock lru_lock_; | |||
| SpinLock free_lock_; | |||
| SpinLock tg_lock_; | |||
| @@ -25,15 +25,22 @@ class DatasetCache: | |||
| A client to interface with tensor caching service | |||
| """ | |||
| def __init__(self, session_id=None, size=0, spilling=False): | |||
| def __init__(self, session_id=None, size=0, spilling=False, port=50052, prefetch_size=20): | |||
| check_uint32(session_id, "session_id") | |||
| check_uint64(size, "size") | |||
| type_check(spilling, (bool,), "spilling") | |||
| check_uint32(port, "port") | |||
| check_uint32(prefetch_size, "prefetch size") | |||
| self.session_id = session_id | |||
| self.size = size | |||
| self.spilling = spilling | |||
| self.cache_client = CacheClient(session_id, size, spilling) | |||
| self.port = port | |||
| self.prefetch_size = prefetch_size | |||
| self.cache_client = CacheClient(session_id, size, spilling, port, prefetch_size) | |||
| def GetStat(self): | |||
| return self.cache_client.GetStat() | |||
| def __deepcopy__(self, memodict): | |||
| if id(self) in memodict: | |||
| @@ -44,5 +51,7 @@ class DatasetCache: | |||
| new_cache.session_id = copy.deepcopy(self.session_id, memodict) | |||
| new_cache.spilling = copy.deepcopy(self.spilling, memodict) | |||
| new_cache.size = copy.deepcopy(self.size, memodict) | |||
| new_cache.port = copy.deepcopy(self.port, memodict) | |||
| new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict) | |||
| new_cache.cache_client = self.cache_client | |||
| return new_cache | |||
| @@ -43,13 +43,18 @@ class MindDataTestCacheOp : public UT::DatasetOpTesting { | |||
| } | |||
| }; | |||
| TEST_F(MindDataTestCacheOp, TestCacheServer) { | |||
| TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) { | |||
| Status rc; | |||
| CacheClient myClient(1, 0, true); // use arbitrary session of 1, size of 0, spilling is true | |||
| CacheClient::Builder builder; | |||
| // use arbitrary session of 1, size of 0, spilling is true | |||
| builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); | |||
| std::shared_ptr<CacheClient> myClient; | |||
| rc = builder.Build(&myClient); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // cksum value of 1 for CreateCache here... normally you do not create a cache directly; the cksum arg is generated. | |||
| rc = myClient.CreateCache(1, true); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| std::cout << myClient << std::endl; | |||
| rc = myClient->CreateCache(1, true); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| std::cout << *myClient << std::endl; | |||
| // Create a schema using the C api's | |||
| int32_t rank = 0; // not used | |||
| @@ -68,11 +73,11 @@ TEST_F(MindDataTestCacheOp, TestCacheServer) { | |||
| std::unordered_map<std::string, int32_t> map; | |||
| rc = testSchema->GetColumnNameMap(&map); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Test the CacheSchema api | |||
| rc = myClient.CacheSchema(map); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myClient->CacheSchema(map); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Create a tensor, take a snapshot and restore it back, and compare. | |||
| std::shared_ptr<Tensor> t; | |||
| @@ -88,48 +93,54 @@ TEST_F(MindDataTestCacheOp, TestCacheServer) { | |||
| TensorRow row; | |||
| row.push_back(t); | |||
| int64_t row_id; | |||
| rc = myClient.WriteRow(row, &row_id); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myClient->WriteRow(row, &row_id); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Switch off build phase. | |||
| rc = myClient.BuildPhaseDone(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myClient->BuildPhaseDone(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Now restore from cache. | |||
| row.clear(); | |||
| rc = myClient.GetRows({row_id}, &tbl); | |||
| rc = myClient->GetRows({row_id}, &tbl); | |||
| row = tbl.front(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| auto r = row.front(); | |||
| std::cout << *r << std::endl; | |||
| // Compare | |||
| bool cmp = (*t == *r); | |||
| EXPECT_TRUE(cmp); | |||
| ASSERT_TRUE(cmp); | |||
| // Get back the schema and verify | |||
| std::unordered_map<std::string, int32_t> map_out; | |||
| rc = myClient.FetchSchema(&map_out); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myClient->FetchSchema(&map_out); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| cmp = (map_out == map); | |||
| EXPECT_TRUE(cmp); | |||
| ASSERT_TRUE(cmp); | |||
| // Test Purge and Destroy | |||
| rc = myClient.PurgeCache(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myClient.DestroyCache(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myClient->PurgeCache(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myClient->DestroyCache(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| } | |||
| TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) { | |||
| TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) { | |||
| // Clear the rc of the master thread if any | |||
| (void)TaskManager::GetMasterThreadRc(); | |||
| TaskGroup vg; | |||
| Status rc; | |||
| CacheClient myClient(1, 1, true); // use arbitrary session of 1, size 1, spilling is true | |||
| // use arbitrary session of 1, size 1, spilling is true | |||
| CacheClient::Builder builder; | |||
| builder.SetSessionId(1).SetCacheMemSz(1).SetSpill(true); | |||
| std::shared_ptr<CacheClient> myClient; | |||
| rc = builder.Build(&myClient); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // cksum value of 1 for CreateCache here... normally you do not create a cache directly; the cksum arg is generated. | |||
| rc = myClient.CreateCache(1, true); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| std::cout << myClient << std::endl; | |||
| rc = myClient->CreateCache(1, true); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| std::cout << *myClient << std::endl; | |||
| std::shared_ptr<Tensor> t; | |||
| Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t); | |||
| t->SetItemAt<uint64_t>({0, 0}, 1); | |||
| @@ -146,19 +157,19 @@ TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) { | |||
| Status vg_rc = vg.CreateAsyncTask("Test agent", [&myClient, &row]() -> Status { | |||
| TaskManager::FindMe()->Post(); | |||
| for (auto i = 0; i < 500; i++) { | |||
| RETURN_IF_NOT_OK(myClient.WriteRow(row)); | |||
| RETURN_IF_NOT_OK(myClient->WriteRow(row)); | |||
| } | |||
| return Status::OK(); | |||
| }); | |||
| EXPECT_TRUE(vg_rc.IsOk()); | |||
| ASSERT_TRUE(vg_rc.IsOk()); | |||
| } | |||
| ASSERT_TRUE(vg.join_all().IsOk()); | |||
| ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOk()); | |||
| rc = myClient.BuildPhaseDone(); | |||
| rc = myClient->BuildPhaseDone(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Get statistics from the server. | |||
| CacheClient::ServiceStat stat{}; | |||
| rc = myClient.GetStat(&stat); | |||
| CacheServiceStat stat{}; | |||
| rc = myClient->GetStat(&stat); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| std::cout << stat.min_row_id << ":" << stat.max_row_id << ":" << stat.num_mem_cached << ":" << stat.num_disk_cached | |||
| << "\n"; | |||
| @@ -168,15 +179,15 @@ TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) { | |||
| for (auto i = stat.min_row_id; i <= stat.max_row_id; ++i) { | |||
| tbl.clear(); | |||
| row.clear(); | |||
| rc = myClient.GetRows({i}, &tbl); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myClient->GetRows({i}, &tbl); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| row = tbl.front(); | |||
| auto r = row.front(); | |||
| bool cmp = (*t == *r); | |||
| EXPECT_TRUE(cmp); | |||
| ASSERT_TRUE(cmp); | |||
| } | |||
| rc = myClient.DestroyCache(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myClient->DestroyCache(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| } | |||
| // Simple test with a repeated cache op over random data producer | |||
| @@ -187,7 +198,7 @@ TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) { | |||
| // | | |||
| // RandomDataOp | |||
| // | |||
| TEST_F(MindDataTestCacheOp, TestRandomDataCache1) { | |||
| TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { | |||
| Status rc; | |||
| int32_t rank = 0; // not used | |||
| MS_LOG(INFO) << "UT test TestRandomDataCache1"; | |||
| @@ -218,13 +229,18 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCache1) { | |||
| .SetDataSchema(std::move(testSchema)) | |||
| .SetTotalRows(50) // 50 samples for now | |||
| .Build(&myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // CacheOp | |||
| // size of 0, spilling is true | |||
| std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true); | |||
| CacheClient::Builder builder; | |||
| // use arbitrary session of 1, size of 0, spilling is true | |||
| builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); | |||
| std::shared_ptr<CacheClient> myClient; | |||
| rc = builder.Build(&myClient); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| std::shared_ptr<CacheOp> myCacheOp; | |||
| int64_t num_samples = 0; | |||
| @@ -236,29 +252,29 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCache1) { | |||
| .SetRowsPerBuffer(4) | |||
| .SetSampler(std::move(seq_sampler)) | |||
| .Build(&myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // RepeatOp | |||
| uint32_t numRepeats = 4; | |||
| std::shared_ptr<RepeatOp> myRepeatOp; | |||
| rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Assign tree relations and root | |||
| rc = myRepeatOp->AddChild(myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myCacheOp->AddChild(myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssignRoot(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | |||
| rc = myTree->Prepare(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // quick check to see what tree looks like | |||
| std::ostringstream ss; | |||
| @@ -268,24 +284,24 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCache1) { | |||
| std::cout << *myClient << std::endl; | |||
| rc = myTree->Launch(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Start the loop of reading tensors from our pipeline | |||
| DatasetIterator dI(myTree); | |||
| TensorRow tensorList; | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| int rowCount = 0; | |||
| while (!tensorList.empty()) { | |||
| // Don't display these rows, just count them | |||
| MS_LOG(INFO) << "Row fetched #: " << rowCount; | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rowCount++; | |||
| } | |||
| ASSERT_EQ(rowCount, 200); | |||
| rc = myClient->DestroyCache(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| } | |||
| //// Simple test with a repeated cache op over random data producer. | |||
| @@ -297,7 +313,7 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCache1) { | |||
| //// | | |||
| //// RandomDataOp | |||
| //// | |||
| TEST_F(MindDataTestCacheOp, TestRandomDataCacheSpill) { | |||
| TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { | |||
| Status rc; | |||
| int32_t rank = 0; // not used | |||
| MS_LOG(INFO) << "UT test TestRandomDataCacheSpill"; | |||
| @@ -328,15 +344,20 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCacheSpill) { | |||
| .SetDataSchema(std::move(testSchema)) | |||
| .SetTotalRows(10) | |||
| .Build(&myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // CacheOp | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 4, true); | |||
| CacheClient::Builder builder; | |||
| // use arbitrary session of 1, size of 4, spilling is true | |||
| builder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true); | |||
| std::shared_ptr<CacheClient> myClient; | |||
| rc = builder.Build(&myClient); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| std::shared_ptr<CacheOp> myCacheOp; | |||
| rc = CacheOp::Builder() | |||
| .SetNumWorkers(4) | |||
| @@ -344,60 +365,65 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCacheSpill) { | |||
| .SetRowsPerBuffer(3) | |||
| .SetSampler(std::move(seq_sampler)) | |||
| .Build(&myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // RepeatOp | |||
| uint32_t numRepeats = 4; | |||
| std::shared_ptr<RepeatOp> myRepeatOp; | |||
| rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Assign tree relations and root | |||
| rc = myRepeatOp->AddChild(myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myCacheOp->AddChild(myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssignRoot(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | |||
| rc = myTree->Prepare(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| std::cout << *myClient << std::endl; | |||
| rc = myTree->Launch(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Start the loop of reading tensors from our pipeline | |||
| DatasetIterator dI(myTree); | |||
| TensorRow tensorList; | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| int rowCount = 0; | |||
| while (!tensorList.empty()) { | |||
| // Don't display these rows, just count them | |||
| MS_LOG(INFO) << "Row fetched #: " << rowCount; | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rowCount++; | |||
| } | |||
| ASSERT_EQ(rowCount, 40); | |||
| rc = myClient->DestroyCache(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| } | |||
| TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { | |||
| TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { | |||
| Status rc; | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true); | |||
| CacheClient::Builder ccbuilder; | |||
| // use arbitrary session of 1, size of 0, spilling is true | |||
| ccbuilder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); | |||
| std::shared_ptr<CacheClient> myClient; | |||
| rc = ccbuilder.Build(&myClient); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // In a mappable dataset, it uses a complex interactions of cache lookup op and cache merge op. | |||
| // Rather than manually build this, the way to do it is to choose the position of the cache in the tree by | |||
| @@ -417,44 +443,44 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { | |||
| .SetRecursive(true) | |||
| .SetImageFolderDir(datasets_root_path_ + "/testPK/data"); | |||
| rc = builder.Build(&so); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // RepeatOp | |||
| uint32_t numRepeats = 4; | |||
| std::shared_ptr<RepeatOp> myRepeatOp; | |||
| rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| auto myTree = std::make_shared<ExecutionTree>(); | |||
| rc = myTree->AssociateNode(so); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssignRoot(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myRepeatOp->AddChild(myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myCacheOp->AddChild(so); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->Prepare(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->Launch(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Start the loop of reading tensors from our pipeline | |||
| DatasetIterator dI(myTree); | |||
| TensorRow tensorList; | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| int rowCount = 0; | |||
| while (!tensorList.empty()) { | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| if (rc.IsError()) { | |||
| std::cout << rc << std::endl; | |||
| break; | |||
| @@ -464,7 +490,7 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { | |||
| ASSERT_EQ(rowCount, 176); | |||
| std::cout << "Row count : " << rowCount << std::endl; | |||
| rc = myClient->DestroyCache(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| } | |||
| //// Simple test with a repeated cache op over random data producer. | |||
| @@ -480,7 +506,7 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { | |||
| //// | | |||
| //// RandomDataOp | |||
| //// | |||
| TEST_F(MindDataTestCacheOp, TestCacheInheritSampler) { | |||
| TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) { | |||
| Status rc; | |||
| int32_t rank = 0; // not used | |||
| MS_LOG(INFO) << "UT test TestCacheInheritSampler"; | |||
| @@ -517,57 +543,62 @@ TEST_F(MindDataTestCacheOp, TestCacheInheritSampler) { | |||
| .SetTotalRows(10) | |||
| .SetSampler(std::move(seq_sampler)) | |||
| .Build(&myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // CacheOp | |||
| std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 4, true); | |||
| CacheClient::Builder ccbuilder; | |||
| // use arbitrary session of 1, size of 4, spilling is true | |||
| ccbuilder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true); | |||
| std::shared_ptr<CacheClient> myClient; | |||
| rc = ccbuilder.Build(&myClient); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| std::shared_ptr<CacheOp> myCacheOp; | |||
| rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // RepeatOp | |||
| uint32_t numRepeats = 4; | |||
| std::shared_ptr<RepeatOp> myRepeatOp; | |||
| rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Assign tree relations and root | |||
| rc = myRepeatOp->AddChild(myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myCacheOp->AddChild(myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssignRoot(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | |||
| rc = myTree->Prepare(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| std::cout << *myClient << std::endl; | |||
| rc = myTree->Launch(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Start the loop of reading tensors from our pipeline | |||
| DatasetIterator dI(myTree); | |||
| TensorRow tensorList; | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| int rowCount = 0; | |||
| while (!tensorList.empty()) { | |||
| // Don't display these rows, just count them | |||
| MS_LOG(INFO) << "Row fetched #: " << rowCount; | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rowCount++; | |||
| } | |||
| ASSERT_EQ(rowCount, 40); | |||
| rc = myClient->DestroyCache(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| } | |||
| @@ -15,6 +15,8 @@ | |||
| """ | |||
| Testing cache operator with mappable datasets | |||
| """ | |||
| import os | |||
| import pytest | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| from mindspore import log as logger | |||
| @@ -25,6 +27,7 @@ DATA_DIR = "../data/dataset/testImageNetData/train/" | |||
| GENERATE_GOLDEN = False | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||
| def test_cache_map_basic1(): | |||
| """ | |||
| Test mappable leaf with cache op right over the leaf | |||
| @@ -53,7 +56,7 @@ def test_cache_map_basic1(): | |||
| logger.info("test_cache_map_basic1 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||
| def test_cache_map_basic2(): | |||
| """ | |||
| Test mappable leaf with the cache op later in the tree above the map(decode) | |||
| @@ -82,7 +85,7 @@ def test_cache_map_basic2(): | |||
| logger.info("test_cache_map_basic2 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||
| def test_cache_map_basic3(): | |||
| """ | |||
| Test a repeat under mappable cache | |||
| @@ -116,7 +119,7 @@ def test_cache_map_basic3(): | |||
| assert num_iter == 8 | |||
| logger.info('test_cache_map_basic3 Ended.\n') | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||
| def test_cache_map_basic4(): | |||
| """ | |||
| Regression test: rows of different sizes must not result in a core dump | |||
| @@ -141,7 +144,7 @@ def test_cache_map_basic4(): | |||
| assert num_iter == 8 | |||
| logger.info('test_cache_map_basic4 Ended.\n') | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||
| def test_cache_map_failure1(): | |||
| """ | |||
| Test nested cache (failure) | |||
| @@ -15,6 +15,8 @@ | |||
| """ | |||
| Testing cache operator with non-mappable datasets | |||
| """ | |||
| import os | |||
| import pytest | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| @@ -25,6 +27,7 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| GENERATE_GOLDEN = False | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||
| def test_cache_nomap_basic1(): | |||
| """ | |||
| A random dataset (a non mappable dataset) with a cache over it just after the leaf | |||
| @@ -54,6 +57,7 @@ def test_cache_nomap_basic1(): | |||
| logger.info("test_cache_nomap_basic1 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||
| def test_cache_nomap_basic2(): | |||
| """ | |||
| A random dataset (a non mappable dataset) with a cache over it just after the leaf | |||
| @@ -85,6 +89,7 @@ def test_cache_nomap_basic2(): | |||
| logger.info("test_cache_nomap_basic2 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||
| def test_cache_nomap_basic3(): | |||
| """ | |||
| A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf | |||
| @@ -112,9 +117,21 @@ def test_cache_nomap_basic3(): | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 12 | |||
| # Contact the server to get the statistics | |||
| stat = some_cache.GetStat() | |||
| cache_sz = stat.avg_cache_sz | |||
| num_mem_cached = stat.num_mem_cached | |||
| num_disk_cached = stat.num_disk_cached | |||
| logger.info("Number of rows cached in memory: {}".format(num_mem_cached)) | |||
| logger.info("Number of rows spilled to disk: {}".format(num_disk_cached)) | |||
| logger.info("Average row cache size: {}".format(cache_sz)) | |||
| logger.info("test_cache_nomap_basic3 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||
| def test_cache_nomap_basic4(): | |||
| """ | |||
| A TF reader dataset (a non mappable dataset) with a map decode and cache after it | |||
| @@ -155,6 +172,7 @@ def test_cache_nomap_basic4(): | |||
| logger.info("test_cache_nomap_basic4 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||
| def test_cache_nomap_basic5(): | |||
| """ | |||
| A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf | |||
| @@ -191,6 +209,7 @@ def test_cache_nomap_basic5(): | |||
| logger.info("test_cache_nomap_basic5 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||
| def test_cache_nomap_basic6(): | |||
| """ | |||
| A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf | |||
| @@ -230,6 +249,7 @@ def test_cache_nomap_basic6(): | |||
| logger.info("test_cache_nomap_basic6 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||
| def test_cache_nomap_basic7(): | |||
| """ | |||
| A TF reader dataset (a non mappable dataset) that uses global shuffle, and is cached followed by | |||
| @@ -265,6 +285,7 @@ def test_cache_nomap_basic7(): | |||
| logger.info("test_cache_nomap_basic7 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||
| def test_cache_nomap_allowed_share1(): | |||
| """ | |||
| It is allowed to share the cache between the following two trees: | |||
| @@ -280,7 +301,7 @@ def test_cache_nomap_allowed_share1(): | |||
| ds.config.set_seed(1) | |||
| # This dataset has 3 records in it only | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True, prefetch_size=32) | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) | |||
| ds1 = ds1.repeat(4) | |||
| @@ -300,6 +321,7 @@ def test_cache_nomap_allowed_share1(): | |||
| logger.info("test_cache_nomap_allowed_share1 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||
| def test_cache_nomap_allowed_share2(): | |||
| """ | |||
| It is allowed to share the cache between the following two trees (with map decode): | |||
| @@ -341,6 +363,7 @@ def test_cache_nomap_allowed_share2(): | |||
| logger.info("test_cache_nomap_allowed_share2 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||
| def test_cache_nomap_allowed_share3(): | |||
| """ | |||
| It is allowed to share the cache between the following two trees (different shard ids): | |||
| @@ -376,6 +399,7 @@ def test_cache_nomap_allowed_share3(): | |||
| logger.info("test_cache_nomap_allowed_share3 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||
| def test_cache_nomap_allowed_share4(): | |||
| """ | |||
| It is allowed to share the cache between the following two trees: | |||
| @@ -414,6 +438,7 @@ def test_cache_nomap_allowed_share4(): | |||
| logger.info("test_cache_nomap_allowed_share4 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server") | |||
| def test_cache_nomap_disallowed_share1(): | |||
| """ | |||
| It is not allowed to share the cache between the following two trees: | |||