Browse Source

Phase 2 of CacheOp

tags/v0.7.0-beta
Jesse Lee 5 years ago
parent
commit
8a08d0c37b
52 changed files with 3929 additions and 808 deletions
  1. +12
    -12
      mindspore/ccsrc/minddata/dataset/CMakeLists.txt
  2. +19
    -1
      mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc
  3. +2
    -1
      mindspore/ccsrc/minddata/dataset/core/constants.h
  4. +3
    -5
      mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt
  5. +42
    -3
      mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt
  6. +70
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin.cc
  7. +396
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc
  8. +105
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h
  9. +73
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc
  10. +52
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h
  11. +91
    -71
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
  12. +144
    -18
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h
  13. +90
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h
  14. +151
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_fbb.cc
  15. +46
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_fbb.h
  16. +54
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc.proto
  17. +161
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.cc
  18. +102
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.h
  19. +203
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc
  20. +103
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h
  21. +121
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc
  22. +175
    -145
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc
  23. +207
    -85
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h
  24. +550
    -123
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc
  25. +154
    -14
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h
  26. +79
    -30
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc
  27. +18
    -4
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h
  28. +14
    -1
      mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs
  29. +45
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/stub/cache_grpc_client.h
  30. +119
    -20
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc
  31. +15
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h
  32. +95
    -80
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc
  33. +30
    -14
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h
  34. +2
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc
  35. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
  36. +1
    -0
      mindspore/ccsrc/minddata/dataset/include/status.h
  37. +1
    -1
      mindspore/ccsrc/minddata/dataset/util/allocator.h
  38. +15
    -15
      mindspore/ccsrc/minddata/dataset/util/arena.h
  39. +9
    -0
      mindspore/ccsrc/minddata/dataset/util/cache_pool.cc
  40. +1
    -0
      mindspore/ccsrc/minddata/dataset/util/cache_pool.h
  41. +127
    -0
      mindspore/ccsrc/minddata/dataset/util/queue_map.h
  42. +11
    -43
      mindspore/ccsrc/minddata/dataset/util/services.cc
  43. +20
    -12
      mindspore/ccsrc/minddata/dataset/util/services.h
  44. +1
    -0
      mindspore/ccsrc/minddata/dataset/util/slice.h
  45. +3
    -0
      mindspore/ccsrc/minddata/dataset/util/status.cc
  46. +1
    -0
      mindspore/ccsrc/minddata/dataset/util/status.h
  47. +2
    -0
      mindspore/ccsrc/minddata/dataset/util/task_manager.cc
  48. +12
    -1
      mindspore/ccsrc/minddata/dataset/util/task_manager.h
  49. +11
    -2
      mindspore/dataset/engine/cache_client.py
  50. +131
    -100
      tests/ut/cpp/dataset/cache_op_test.cc
  51. +7
    -4
      tests/ut/python/dataset/test_cache_map.py
  52. +26
    -1
      tests/ut/python/dataset/test_cache_nomap.py

+ 12
- 12
mindspore/ccsrc/minddata/dataset/CMakeLists.txt View File

@@ -24,6 +24,11 @@ if (ENABLE_TDTQUE)
add_definitions(-D ENABLE_TDTQUE)
message(STATUS "TDT queue is enabled")
endif ()
if (MS_BUILD_GRPC)
set (ENABLE_CACHE true)
add_definitions(-D ENABLE_CACHE)
message(STATUS "Cache is enabled")
endif()

# conde coverage
# option(ENABLE_COVERAGE "Enable code coverage report" OFF)
@@ -47,10 +52,6 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")

include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})

################## Include sub-modules ###############################
add_subdirectory(util)
add_subdirectory(core)
@@ -70,8 +71,6 @@ add_dependencies(engine-datasetops-source-sampler core)
add_dependencies(engine-datasetops core)
add_dependencies(engine-datasetops-mapop core)
add_dependencies(engine-opt core)
add_dependencies(engine-cache-client core)
add_dependencies(engine-cache-server core)
add_dependencies(engine-perf core)
add_dependencies(engine-gnn core)
add_dependencies(engine core)
@@ -85,7 +84,11 @@ endif()
if (ENABLE_TDTQUE)
add_dependencies(engine-tdt core)
endif ()

if (ENABLE_CACHE)
add_dependencies(engine-datasetops engine-cache-client)
add_dependencies(engine-cache-client core)
add_dependencies(engine-cache-server core)
endif ()
################### Create _c_dataengine Library ######################
set(submodules
$<TARGET_OBJECTS:core>
@@ -105,7 +108,6 @@ set(submodules
$<TARGET_OBJECTS:engine-datasetops>
$<TARGET_OBJECTS:engine-opt>
$<TARGET_OBJECTS:engine-cache-client>
$<TARGET_OBJECTS:engine-cache-server>
$<TARGET_OBJECTS:engine>
$<TARGET_OBJECTS:text>
$<TARGET_OBJECTS:text-kernels>
@@ -123,8 +125,6 @@ else ()
add_library(_c_dataengine SHARED ${submodules})
endif ()

add_dependencies(_c_dataengine generated_engine_files)

if (ENABLE_PYTHON)
set_target_properties(_c_dataengine PROPERTIES
PREFIX "${PYTHON_MODULE_PREFIX}"
@@ -187,6 +187,6 @@ else()
endif ()
endif()

if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
if (MS_BUILD_GRPC)
target_link_libraries(_c_dataengine PRIVATE mindspore::grpc++)
endif()
endif()

+ 19
- 1
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc View File

@@ -22,7 +22,25 @@ namespace dataset {

PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) {
(void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
.def(py::init<uint32_t, uint64_t, bool>());
.def(
py::init([](session_id_type id, uint64_t mem_sz, bool spill, int32_t port, int32_t prefetch_sz) {
std::shared_ptr<CacheClient> cc;
CacheClient::Builder builder;
builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPort(port).SetPrefetchSize(
prefetch_sz);
THROW_IF_ERROR(builder.Build(&cc));
return cc;
}))
.def("GetStat", [](CacheClient &cc) {
CacheServiceStat stat{};
THROW_IF_ERROR(cc.GetStat(&stat));
return stat;
});
(void)py::class_<CacheServiceStat>(*m, "CacheServiceStat")
.def(py::init<>())
.def_readwrite("avg_cache_sz", &CacheServiceStat::avg_cache_sz)
.def_readwrite("num_mem_cached", &CacheServiceStat::num_mem_cached)
.def_readwrite("num_disk_cached", &CacheServiceStat::num_disk_cached);
}));

} // namespace dataset


+ 2
- 1
mindspore/ccsrc/minddata/dataset/core/constants.h View File

@@ -72,7 +72,8 @@ constexpr uint32_t kCfgMonitorSamplingInterval = 10;
// Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h)
constexpr uint8_t kCVInvalidType = 255;

using connection_id_type = int64_t;
using connection_id_type = uint64_t;
using session_id_type = uint32_t;
using row_id_type = int64_t;
} // namespace dataset
} // namespace mindspore


+ 3
- 5
mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt View File

@@ -20,10 +20,8 @@ if (ENABLE_PYTHON)
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
endif()

add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf engine-cache-client engine-datasetops-mapop)

if (ENABLE_TDTQUE)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf
engine-cache-client engine-cache-server engine-datasetops-mapop)
else ()
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf
engine-cache-client engine-cache-server engine-datasetops-mapop)
add_dependencies(engine engine-tdt)
endif ()

+ 42
- 3
mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt View File

@@ -1,8 +1,47 @@
include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})

file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)

add_library(engine-cache-client OBJECT
cache_client.cc
cache_fbb.cc
cache_request.cc)
add_library(engine-cache-server OBJECT
cache_service.cc
cache_server.cc)

if (ENABLE_CACHE)
ms_grpc_generate(CACHE_GRPC_SRCS CACHE_GRPC_HDRS cache_grpc.proto)
target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS} cache_grpc_client.cc)

add_library(engine-cache-server OBJECT
${CACHE_GRPC_SRCS}
cache_grpc_server.cc
cache_arena.cc
cache_service.cc
cache_server.cc)

add_executable(cache_server cache_main.cc)
target_link_libraries(cache_server
engine-cache-server
$<TARGET_OBJECTS:utils>
mindspore
mindspore::glog
mindspore::protobuf
mindspore::grpc++
mindspore_gvar
${PYTHON_LIBRARIES}
${SECUREC_LIBRARY}
pthread)

add_executable(cache_admin cache_admin.cc cache_admin_arg.cc)
target_link_libraries(cache_admin _c_dataengine _c_mindrecord ${PYTHON_LIBRARIES} mindspore::glog)

add_dependencies(engine-cache-server generated_engine_files)

else ()
ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PRTO_HDRS cache_grpc.proto)
target_sources(engine-cache-client PUBLIC ${CACHE_PROTO_SRCS})
endif ()

add_dependencies(engine-cache-client generated_engine_files)

+ 70
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin.cc View File

@@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <unistd.h>
#include <iostream>
#ifdef USE_GLOG
#include <glog/logging.h>
#endif
#include "minddata/dataset/engine/cache/cache_admin_arg.h"

namespace ds = mindspore::dataset;

int main(int argc, char **argv) {
ds::Status rc;
ds::CacheAdminArgHandler args;
std::stringstream arg_stream;

#ifdef USE_GLOG
FLAGS_log_dir = "/tmp";
google::InitGoogleLogging(argv[0]);
#endif

std::string warningMsg;
warningMsg.reserve(512);
warningMsg += "WARNING:\n";
warningMsg += "cache_admin and the cache server that it controls are currently only used for experimental research";
warningMsg += " purposes at this time.\n";
warningMsg += "It is not intended for general availability yet as it may not be stable. Use it at your own risk.\n";

// A warning message until the code is mature enough.
std::cerr << warningMsg << std::endl;

if (argc == 1) {
args.Help();
return 0;
}

// ingest all the args into a string stream for parsing
for (int i = 1; i < argc; ++i) {
arg_stream << " " << std::string(argv[i]);
}

// Parse the args
rc = args.ParseArgStream(&arg_stream);
if (!rc.IsOk()) {
std::cerr << rc.ToString() << std::endl;
return 1;
}

// Execute the command
rc = args.RunCommand();
if (!rc.IsOk()) {
std::cerr << rc.ToString() << std::endl;
return 1;
}

return 0;
}

+ 396
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc View File

@@ -0,0 +1,396 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_admin_arg.h"
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <unistd.h>
#include <cerrno>
#include <iostream>
#include <string>
#include <cstdlib>
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/util/path.h"

namespace mindspore {
namespace dataset {

const char CacheAdminArgHandler::kDefaultHost[] = "127.0.0.1";
const char CacheAdminArgHandler::kServerBinary[] = "cache_server";
const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp";

CacheAdminArgHandler::CacheAdminArgHandler()
: port_(kDefaultPort),
session_id_(0),
num_workers_(kDefaultNumWorkers),
shm_mem_sz_(kDefaultSharedMemorySizeInGB),
log_level_(kDefaultLogLevel),
hostname_(kDefaultHost),
spill_dir_(kDefaultSpillDir),
command_id_(CommandId::kCmdUnknown) {
// Initialize the command mappings
arg_map_["-h"] = ArgValue::kArgHost;
arg_map_["--hostname"] = ArgValue::kArgHost;
arg_map_["-p"] = ArgValue::kArgPort;
arg_map_["--port"] = ArgValue::kArgPort;
arg_map_["--start"] = ArgValue::kArgStart;
arg_map_["--stop"] = ArgValue::kArgStop;
arg_map_["--help"] = ArgValue::kArgHelp;
arg_map_["--generate_session"] = ArgValue::kArgGenerateSession;
arg_map_["-g"] = ArgValue::kArgGenerateSession;
arg_map_["--destroy_session"] = ArgValue::kArgDestroySession;
arg_map_["-d"] = ArgValue::kArgDestroySession;
arg_map_["--spilldir"] = ArgValue::kArgSpillDir;
arg_map_["-s"] = ArgValue::kArgSpillDir;
arg_map_["-w"] = ArgValue::kArgNumWorkers;
arg_map_["--workers"] = ArgValue::kArgNumWorkers;
arg_map_["-m"] = ArgValue::kArgSharedMemorySize;
arg_map_["--shared_memory_size"] = ArgValue::kArgSharedMemorySize;
arg_map_["-l"] = ArgValue::kArgLogLevel;
arg_map_["--minloglevel"] = ArgValue::kArgLogLevel;
// Initialize argument tracker with false values
for (int16_t i = 0; i < static_cast<int16_t>(ArgValue::kArgNumArgs); ++i) {
ArgValue currAV = static_cast<ArgValue>(i);
used_args_[currAV] = false;
}
}

Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
CommandId command_id) {
// Detect if the user tried to provide this argument more than once
ArgValue selected_arg = arg_map_[option];
if (used_args_[selected_arg]) {
std::string err_msg = "The " + option + " argument was given more than once.";
return Status(StatusCode::kSyntaxError, err_msg);
}

// Flag that this arg is used now
used_args_[selected_arg] = true;

// Some options are just arguments, for example "--port 50052" is not a command, it's just a argument.
// Other options are actual commands, for example "--destroy_session 1234". This executes the destroy session.
// If this option is also a command, make sure there has not been multiple commands given before assigning it.
if (command_id != CommandId::kCmdUnknown) {
if (command_id_ != CommandId::kCmdUnknown) {
std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
return Status(StatusCode::kSyntaxError, err_msg);
} else {
command_id_ = command_id;
}
}

std::string value_as_string;

// Fetch the argument from the arg stream into a string
*arg_stream >> value_as_string;
if (value_as_string.empty()) {
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
return Status(StatusCode::kSyntaxError, err_msg);
}

// Now, attempt to convert the value into it's string format for output
try {
*out_arg = std::stoul(value_as_string);
} catch (const std::exception &e) {
std::string err_msg = "Invalid numeric value: " + value_as_string;
return Status(StatusCode::kSyntaxError, err_msg);
}

return Status::OK();
}

Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream,
CommandId command_id) {
// Detect if the user tried to provide this argument more than once
ArgValue selected_arg = arg_map_[option];
if (used_args_[selected_arg]) {
std::string err_msg = "The " + option + " argument was given more than once.";
return Status(StatusCode::kSyntaxError, err_msg);
}

// Flag that this arg is used now
used_args_[selected_arg] = true;

// Some options are just arguments, for example "--hostname "127.0.0.1" is not a command, it's just an argument.
// Other options are actual commands, for example "--start".
// If this option is also a command, make sure there has not been multiple commands given before assigning it.
if (command_id != CommandId::kCmdUnknown) {
if (command_id_ != CommandId::kCmdUnknown) {
std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
return Status(StatusCode::kSyntaxError, err_msg);
} else {
command_id_ = command_id;
}
}

// If there is no argument to get, such as the --start command, then out_arg will be a nullptr.
if (out_arg != nullptr) {
// Fetch the argument from the arg stream into a string
*arg_stream >> *out_arg;
if (out_arg->empty()) {
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
return Status(StatusCode::kSyntaxError, err_msg);
}
}

return Status::OK();
}

Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) {
std::string tok;
while (*arg_stream >> tok) {
switch (arg_map_[tok]) {
case ArgValue::kArgHost: {
RETURN_IF_NOT_OK(AssignArg(tok, &hostname_, arg_stream));
break;
}
case ArgValue::kArgPort: {
RETURN_IF_NOT_OK(AssignArg(tok, &port_, arg_stream));
break;
}
case ArgValue::kArgStart: {
RETURN_IF_NOT_OK(AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdStart));
break;
}
case ArgValue::kArgStop: {
RETURN_IF_NOT_OK(AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdStop));
break;
}
case ArgValue::kArgGenerateSession: {
RETURN_IF_NOT_OK(
AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdGenerateSession));
break;
}
case ArgValue::kArgHelp: {
command_id_ = CommandId::kCmdHelp;
break;
}
case ArgValue::kArgDestroySession: {
// session_id is an unsigned type. We may need to template the AssignArg function so that
// it can handle different flavours of integers instead of just int32_t.
int32_t session_int;
RETURN_IF_NOT_OK(AssignArg(tok, &session_int, arg_stream, CommandId::kCmdDestroySession));
session_id_ = session_int;
break;
}
case ArgValue::kArgNumWorkers: {
RETURN_IF_NOT_OK(AssignArg(tok, &num_workers_, arg_stream));
break;
}
case ArgValue::kArgSpillDir: {
RETURN_IF_NOT_OK(AssignArg(tok, &spill_dir_, arg_stream));
break;
}
case ArgValue::kArgSharedMemorySize: {
RETURN_IF_NOT_OK(AssignArg(tok, &shm_mem_sz_, arg_stream));
break;
}
case ArgValue::kArgLogLevel: {
RETURN_IF_NOT_OK(AssignArg(tok, &log_level_, arg_stream));
break;
}
default: {
// Save space delimited trailing arguments
trailing_args_ += (" " + tok);
break;
}
}
}

RETURN_IF_NOT_OK(Validate());

return Status::OK();
}

Status CacheAdminArgHandler::Validate() {
// This sanity check is delayed until now in case there may be valid use-cases of trailing args.
// Any unhandled arguments at this point is an error.
if (!trailing_args_.empty()) {
std::string err_msg = "Invalid arguments provided: " + trailing_args_;
return Status(StatusCode::kSyntaxError, err_msg);
}

// The user must pick at least one command. i.e. it's meaningless to just give a hostname or port but no command to
// run.
if (command_id_ == CommandId::kCmdUnknown) {
std::string err_msg = "No command provided";
return Status(StatusCode::kSyntaxError, err_msg);
}

// Additional checks here
if (num_workers_ < 1) return Status(StatusCode::kSyntaxError, "Number of workers must be positive value.");
if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3).");
// port range check?

return Status::OK();
}

Status CacheAdminArgHandler::RunCommand() {
switch (command_id_) {
case CommandId::kCmdHelp: {
Help();
break;
}
case CommandId::kCmdStart: {
RETURN_IF_NOT_OK(StartServer());
break;
}
case CommandId::kCmdStop: {
RETURN_IF_NOT_OK(StopServer());
break;
}
case CommandId::kCmdGenerateSession: {
CacheClientGreeter comm(hostname_, port_, 1);
RETURN_IF_NOT_OK(comm.ServiceStart());
auto rq = std::make_shared<GenerateSessionIdRequest>();
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
std::cout << rq->GetSessionId() << std::endl;
break;
}
case CommandId::kCmdDestroySession: {
CacheClientGreeter comm(hostname_, port_, 1);
RETURN_IF_NOT_OK(comm.ServiceStart());
CacheClientInfo cinfo;
cinfo.set_session_id(session_id_);
auto rq = std::make_shared<DropSessionRequest>(cinfo);
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
std::cout << "Drop session successful" << std::endl;
break;
}
default: {
RETURN_STATUS_UNEXPECTED("Invalid cache admin command id.");
break;
}
}

return Status::OK();
}

Status CacheAdminArgHandler::StartServer() {
// There currently does not exist any "install path" or method to identify which path the installed binaries will
// exist in. As a temporary approach, we will assume that the server binary shall exist in the same path as the
// cache_admin binary (this process).
const std::string self_proc = "/proc/self/exe";
std::string canonical_path;
canonical_path.resize(400); // PATH_MAX is large. This value should be big enough for our use.
// Some lower level OS library calls are needed here to determine the binary path.
// Fetch the path of this binary for admin_cache into C character array and then truncate off the binary name so that
// we are left with only the absolute path
if (realpath(self_proc.data(), canonical_path.data()) == nullptr) {
std::string err_msg = "Failed to identify cache admin binary path: " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(err_msg);
}
canonical_path.resize(strlen(canonical_path.data()));
int last_seperator = canonical_path.find_last_of('/');
CHECK_FAIL_RETURN_UNEXPECTED(last_seperator != std::string::npos, "No / found");
// truncate the binary name so we are left with the absolute path of cache_admin binary
canonical_path.resize(last_seperator + 1);
std::string cache_server_binary = canonical_path + std::string(kServerBinary);

// Create a pipe before we fork. If all goes well, the child will run as a daemon in the background
// and never returns until shutdown. If there is any error, the child will notify us through the pipe.
int fd[2];
if (pipe(fd) == -1) {
std::string err_msg = "Failed to create a pipe for communication " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(err_msg);
}

// fork the child process to become the daemon
pid_t pid;
pid = fork();

// failed to fork
if (pid < 0) {
std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(err_msg);
} else if (pid > 0) {
// As a parent, we close the write end. We only listen.
close(fd[1]);
dup2(fd[0], 0);
close(fd[0]);
wait(nullptr);
std::string msg;
const int32_t buf_sz = 1024;
msg.resize(buf_sz);
auto n = read(0, msg.data(), buf_sz);
if (n < 0) {
std::string err_msg = "Failed to read from pipeline " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(err_msg);
}
msg.resize(n);
std::cout << msg << std::endl;
return Status::OK();
} else {
// Child here ...
// Close all stdin, redirect stdout and stderr to the write end of the pipe.
close(fd[0]);
dup2(fd[1], 1);
dup2(fd[1], 2);
close(0);
close(fd[1]);
// exec the cache server binary in this process
std::string port_string = std::to_string(port_);
std::string workers_string = std::to_string(num_workers_);
std::string shared_memory_string = std::to_string(shm_mem_sz_);
std::string minloglevel_string = std::to_string(log_level_);
std::string daemonize_string = "true";

char *argv[8];
argv[0] = cache_server_binary.data(); // First arg is usually the binary name
argv[1] = spill_dir_.data();
argv[2] = workers_string.data();
argv[3] = port_string.data();
argv[4] = shared_memory_string.data();
argv[5] = minloglevel_string.data();
argv[6] = daemonize_string.data();
argv[7] = nullptr;

// Now exec the binary
execv(argv[0], argv);
// If the exec was successful, this line will never be reached due to process image being replaced.
// ..unless exec failed.
std::string err_msg = "Failed to exec cache server: " + cache_server_binary;
std::cerr << err_msg << std::endl;
RETURN_STATUS_UNEXPECTED(err_msg);
}
}

Status CacheAdminArgHandler::StopServer() {
CacheClientGreeter comm(hostname_, port_, 1);
RETURN_IF_NOT_OK(comm.ServiceStart());
auto rq = std::make_shared<ShutdownRequest>();
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
return Status::OK();
}

void CacheAdminArgHandler::Help() {
std::cerr << "Syntax:\n";
std::cerr << " cache_admin [--start | --stop]\n";
std::cerr << " [ [-h | --hostname] <hostname> ]\n";
std::cerr << " [ [-p | --port] <port number> ]\n";
std::cerr << " [ [-g | --generate_session] ]\n";
std::cerr << " [ [-d | --destroy_session] <session id> ]\n";
std::cerr << " [ [-w | --workers] <number of workers> ]\n";
std::cerr << " [ [-s | --spilldir] <spilling directory> ]\n";
std::cerr << " [ [-m | --shared_memory_size] <shared memory size> ]\n";
std::cerr << " [ [-l | --minloglevel] <log level> ]\n";
std::cerr << " [--help]" << std::endl;
}
} // namespace dataset
} // namespace mindspore

+ 105
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h View File

@@ -0,0 +1,105 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <sstream>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/cache/cache_client.h"

namespace mindspore {
namespace dataset {

class CacheAdminArgHandler {
public:
static constexpr int32_t kDefaultPort = 50052;
static constexpr int32_t kDefaultNumWorkers = 32;
static constexpr int32_t kDefaultSharedMemorySizeInGB = 4;
static constexpr int32_t kDefaultLogLevel = 1;
static const char kDefaultHost[];
static const char kServerBinary[];
static const char kDefaultSpillDir[];

// These are the actual command types to execute
enum class CommandId : int16_t {
kCmdHelp = 0,
kCmdStart = 1,
kCmdStop = 2,
kCmdGenerateSession = 3,
kCmdDestroySession = 4,
kCmdUnknown = 32767
};

CacheAdminArgHandler();

~CacheAdminArgHandler() = default;

Status ParseArgStream(std::stringstream *arg_stream);

Status RunCommand();

void Help();

private:
// These are the supported argument string integer mappings
enum class ArgValue : int16_t {
kArgUnknown = 0, // Must be at position 0. invalid map lookups in arg_map_ default to value 0
kArgStart = 1,
kArgStop = 2,
kArgHost = 3,
kArgPort = 4,
kArgHelp = 5,
kArgGenerateSession = 6,
kArgDestroySession = 7,
kArgSpillDir = 8,
kArgNumWorkers = 9,
kArgSharedMemorySize = 10,
kArgLogLevel = 11,
kArgNumArgs = 12 // Must be the last position to provide a count
};

Status StartServer();

Status StopServer();

Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
CommandId command_id = CommandId::kCmdUnknown);

Status AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream,
CommandId command_id = CommandId::kCmdUnknown);

Status Validate();

CommandId command_id_;
int32_t port_;
int32_t num_workers_;
int32_t shm_mem_sz_;
int32_t log_level_;
session_id_type session_id_;
std::string hostname_;
std::string spill_dir_;
std::string trailing_args_;
std::map<std::string, ArgValue> arg_map_;
std::map<ArgValue, bool> used_args_;
};
} // namespace dataset
} // namespace mindspore

#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_

+ 73
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc View File

@@ -0,0 +1,73 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB)
: Arena::Arena(val_in_GB * 1024), port_(port), shmid_(-1) {}

CachedSharedMemoryArena::~CachedSharedMemoryArena() {
#if CACHE_LOCAL_CLIENT
if (this->ptr_ != nullptr && this->ptr_ != reinterpret_cast<void *>(-1)) {
shmdt(this->ptr_);
}
this->ptr_ = nullptr;
if (shmid_ != -1) {
shmctl(shmid_, IPC_RMID, nullptr);
// Also remove the path we use to generate ftok.
Path p(PortToUnixSocketPath(port_));
(void)p.Remove();
}
#endif
}

Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
size_t val_in_GB) {
RETURN_UNEXPECTED_IF_NULL(out);
#if CACHE_LOCAL_CLIENT
auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB);
if (ba == nullptr) {
return Status(StatusCode::kOutOfMemory);
}
// Transfer the ownership of this pointer. Any future error in the processing we will have
// the destructor of *out to deal.
(*out).reset(ba);
// Generate the ftok using a combination of port.
int err;
auto shm_key = PortToFtok(port, &err);
if (shm_key == (key_t)-1) {
std::string errMsg = "Ftok failed with errno " + std::to_string(err);
RETURN_STATUS_UNEXPECTED(errMsg);
}
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
ba->shmid_ = shmget(shm_key, ba->size_in_bytes_, IPC_CREAT | IPC_EXCL | access_mode);
if (ba->shmid_) {
ba->ptr_ = shmat(ba->shmid_, nullptr, 0);
if (ba->ptr_ == reinterpret_cast<void *>(-1)) {
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
}
} else {
RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno));
}
uint64_t num_blks = ba->size_in_bytes_ / ARENA_BLK_SZ;
MS_LOG(DEBUG) << "Size of memory pool is " << num_blks << ", number of blocks of size is " << ARENA_BLK_SZ << ".";
ba->tr_.Insert(0, num_blks);
#endif
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 52
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h View File

@@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_

#include <memory>
#include <string>
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/engine/cache/cache_common.h"
namespace mindspore {
namespace dataset {
/// This is a derived class of Arena but resides in shared memory
class CachedSharedMemoryArena : public Arena {
public:
~CachedSharedMemoryArena() override;
/// \brief Create an Arena in shared memory
/// \param[out] p_ba Pointer to a unique_ptr
/// \param shmkey Shared memory key
/// \param val_in_GB size of shared memory in gigabyte
/// \return Status object
static Status CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port, size_t val_in_GB);

/// \brief This returns where we attach to the shared memory.
/// Some gRPC requests will ask for a shared memory block, and
/// we can't return the absolute address as this makes no sense
/// in the client. So instead we will return an address relative
/// to the base address of the shared memory where we attach to.
/// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return this->ptr_; }

private:
int32_t port_;
int shmid_;
/// Private constructor. Not to be called directly.
CachedSharedMemoryArena(int32_t port, size_t val_in_GB);
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_

+ 91
- 71
mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc View File

@@ -17,29 +17,45 @@
#include <iomanip>
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_fbb.h"
#include "minddata/dataset/util/bit.h"

namespace mindspore {
namespace dataset {

// Constructor
CacheClient::CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill)
: server_connection_id_(0), session_id_(session_id), cache_crc_(0), cache_mem_sz_(cache_mem_sz), spill_(spill) {}
CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname,
int32_t port, int32_t num_workers, int32_t prefetch_size)
: server_connection_id_(0),
cache_mem_sz_(cache_mem_sz),
spill_(spill),
local_bypass_(false),
hostname_(std::move(hostname)),
port_(port),
num_workers_(num_workers),
prefetch_size_(prefetch_size) {
cinfo_.set_session_id(session_id);
comm_ = std::make_shared<CacheClientGreeter>(hostname_, port_, num_workers_);
}

// print method for display cache details
void CacheClient::Print(std::ostream &out) const {
out << " Session id: " << session_id_ << "\n Cache crc: " << cache_crc_
<< "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << cache_mem_sz_
<< "\n Spilling: " << std::boolalpha << spill_;
out << " Session id: " << session_id() << "\n Cache crc: " << cinfo_.crc()
<< "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << getCacheMemSz()
<< "\n Spilling: " << std::boolalpha << isSpill() << "\n Hostname: " << getHostname()
<< "\n Port: " << getPort() << "\n Number of rpc workers: " << getNumWorkers()
<< "\n Prefetch size: " << getPrefetchSize() << "\n Local client support: " << std::boolalpha
<< SupportLocalClient();
}

Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
CacheRowRequest rq(server_connection_id_, cookie());
RETURN_IF_NOT_OK(rq.SerializeCacheRowRequest(row));
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
auto rq = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(this, row));
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
if (row_id_from_server != nullptr) {
*row_id_from_server = rq.GetRowIdAfterCache();
*row_id_from_server = rq->GetRowIdAfterCache();
}
return Status::OK();
}
@@ -47,29 +63,19 @@ Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_serv
Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
std::unique_ptr<DataBuffer> db_ptr = std::move(in);
auto num_rows = db_ptr->NumRows();
std::vector<TensorRow> all_rows;
// We will send the requests async first on all rows and do a final wait.
if (num_rows > 0) {
all_rows.reserve(num_rows);
// Break down the DataBuffer into TensorRow. We will send the requests async
// and then do a final wait.
MemGuard<CacheRowRequest> rq_arr;
RETURN_IF_NOT_OK(rq_arr.allocate(num_rows, server_connection_id_, cookie()));
CacheServer &cs = CacheServer::GetInstance();
auto arr = std::make_unique<std::shared_ptr<CacheRowRequest>[]>(num_rows);
for (auto i = 0; i < num_rows; ++i) {
TensorRow row;
auto rq = rq_arr[i];
RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(row));
RETURN_IF_NOT_OK(cs.PushRequest(rq));
// We can't let row go out of scope. Otherwise it will free all the tensor memory.
// So park it in the vector. When this function go out of scope, its memory
// will be freed.
all_rows.push_back(std::move(row));
arr[i] = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
RETURN_IF_NOT_OK(arr[i]->SerializeCacheRowRequest(this, row));
RETURN_IF_NOT_OK(PushRequest(arr[i]));
}
// Now we wait for the requests to be done.
// Now we wait for them to come back
for (auto i = 0; i < num_rows; ++i) {
auto rq = rq_arr[i];
RETURN_IF_NOT_OK(rq->Wait());
RETURN_IF_NOT_OK(arr[i]->Wait());
}
}
return Status::OK();
@@ -77,11 +83,21 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {

Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
BatchFetchRequest rq(server_connection_id_, row_id);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
RETURN_IF_NOT_OK(rq.RestoreRows(out));
return Status::OK();
auto rq = std::make_shared<BatchFetchRequest>(server_connection_id_, row_id, SupportLocalClient());
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
int64_t mem_addr;
Status rc = rq->RestoreRows(out, comm_->SharedMemoryBaseAddr(), &mem_addr);
// Free the memory by sending a request back to the server.
if (mem_addr != -1) {
auto mfree_req = std::make_shared<FreeSharedBlockRequest>(server_connection_id_, mem_addr);
Status rc2 = PushRequest(mfree_req);
// But we won't wait for the result for the sake of performance.
if (rc.IsOk() && rc2.IsError()) {
rc = rc2;
}
}
return rc;
}

Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
@@ -108,40 +124,44 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
// to create a cache and some other tree is trying to use the same cache.
// That is allowed, however the crc better match!
if (server_connection_id_) {
if (cache_crc_ != tree_crc) {
if (cinfo_.crc() != tree_crc) {
RETURN_STATUS_UNEXPECTED("Attempt to re-use a cache for a different tree!");
}
// Check the state of the server. For non-mappable case where there is a build phase and a fetch phase, we should
// skip the build phase.
lck.Unlock(); // GetStat will grab the mutex again. So unlock it to prevent deadlock.
CacheClient::ServiceStat stat{};
CacheServiceStat stat{};
RETURN_IF_NOT_OK(GetStat(&stat));
if (stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase)) {
return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase");
}
} else {
cache_crc_ = tree_crc; // It's really a new cache we're creating so save our crc in the client
// Combine the session and crc. This will form our client cache identifier.
connection_id_type connection_identification = (static_cast<uint64_t>(session_id_) << 32) | cache_crc_;
cinfo_.set_crc(tree_crc); // It's really a new cache we're creating so save our crc in the client
// Now execute the cache create request using this identifier and other configs
BaseRequest::CreateCacheFlag createFlag = BaseRequest::CreateCacheFlag::kNone;
CreateCacheRequest::CreateCacheFlag createFlag = CreateCacheRequest::CreateCacheFlag::kNone;
if (spill_) {
createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk;
createFlag |= CreateCacheRequest::CreateCacheFlag::kSpillToDisk;
}
if (generate_id) {
createFlag |= BaseRequest::CreateCacheFlag::kGenerateRowId;
createFlag |= CreateCacheRequest::CreateCacheFlag::kGenerateRowId;
}
CreationCacheRequest rq(connection_identification, cache_mem_sz_, createFlag);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
Status rc = rq.Wait();
// Start the comm layer to receive reply
RETURN_IF_NOT_OK(comm_->ServiceStart());
// Initiate connection
auto rq = std::make_shared<CreateCacheRequest>(cinfo_, cache_mem_sz_, createFlag);
RETURN_IF_NOT_OK(PushRequest(rq));
Status rc = rq->Wait();
if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) {
server_connection_id_ = rq.GetServerConnectionId();
std::string cookie;
rq->ParseResult(&server_connection_id_, &cookie);
if (rc.IsOk()) {
// The 1st guy creating the cache will get a cookie back.
// But this object may be shared among pipelines and we don't want
// overwrite it.
cookie_ = rq.cookie();
cookie_ = cookie;
}
// Attach to shared memory for local client
RETURN_IF_NOT_OK(comm_->AttachToSharedMemory(port_, &local_bypass_));
}
// We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the
// CacheOp to bypass the build phase.
@@ -152,57 +172,57 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {

Status CacheClient::PurgeCache() {
UniqueLock lck(&mux_);
PurgeCacheRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
return rq.Wait();
auto rq = std::make_shared<PurgeCacheRequest>(server_connection_id_);
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
return Status::OK();
}

Status CacheClient::DestroyCache() {
UniqueLock lck(&mux_);
DestroyCacheRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
return rq.Wait();
auto rq = std::make_shared<DestroyCacheRequest>(server_connection_id_);
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
return Status::OK();
}

Status CacheClient::GetStat(ServiceStat *stat) {
Status CacheClient::GetStat(CacheServiceStat *stat) {
SharedLock lck(&mux_);
RETURN_UNEXPECTED_IF_NULL(stat);
GetStatRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
stat->num_disk_cached = rq.GetNumDiskCached();
stat->num_mem_cached = rq.GetNumMemCached();
stat->min_row_id = rq.GetMinRowId();
stat->max_row_id = rq.GetMaxRowId();
stat->cache_service_state = rq.GetState();
auto rq = std::make_shared<GetStatRequest>(server_connection_id_);
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
rq->GetStat(stat);
return Status::OK();
}

Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
SharedLock lck(&mux_);
CacheSchemaRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(rq.SerializeCacheSchemaRequest(map));
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
auto rq = std::make_shared<CacheSchemaRequest>(server_connection_id_);
RETURN_IF_NOT_OK(rq->SerializeCacheSchemaRequest(map));
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
return Status::OK();
}

Status CacheClient::FetchSchema(std::unordered_map<std::string, int32_t> *map) {
SharedLock lck(&mux_);
RETURN_UNEXPECTED_IF_NULL(map);
FetchSchemaRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
*map = rq.GetColumnMap();
auto rq = std::make_shared<FetchSchemaRequest>(server_connection_id_);
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
*map = rq->GetColumnMap();
return Status::OK();
}

Status CacheClient::BuildPhaseDone() const {
SharedLock lck(&mux_);
BuildPhaseDoneRequest rq(server_connection_id_, cookie());
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
auto rq = std::make_shared<BuildPhaseDoneRequest>(server_connection_id_, cookie());
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
return Status::OK();
}

Status CacheClient::PushRequest(std::shared_ptr<BaseRequest> rq) const { return comm_->HandleRequest(std::move(rq)); }
} // namespace dataset
} // namespace mindspore

+ 144
- 18
mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h View File

@@ -23,9 +23,13 @@
#include <utility>
#include <vector>

#include "minddata/dataset/core/config_manager.h"
#ifdef ENABLE_CACHE
#include "minddata/dataset/engine/cache/cache_grpc_client.h"
#else
#include "minddata/dataset/engine/cache/stub/cache_grpc_client.h"
#endif
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
#include "minddata/dataset/util/lock.h"

namespace mindspore {
@@ -35,18 +39,120 @@ namespace dataset {
/// rows, etc.
class CacheClient {
public:
friend class CacheMergeOp;

/// \brief A builder to help creating a CacheClient object
class Builder {
public:
Builder() : session_id_(0), cache_mem_sz_(0), spill_(false), port_(0), num_workers_(0), prefetch_size_(0) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
hostname_ = "127.0.0.1";
port_ = 50052;
num_workers_ = cfg->num_parallel_workers();
prefetch_size_ = 20; // rows_per_buf is too small (1 by default).
}

/// Setter function to set the session id
/// \param session_id
/// \return Builder object itself.
Builder &SetSessionId(session_id_type session_id) {
session_id_ = session_id;
return *this;
}

/// Setter function to set the cache memory size
/// \param cache_mem_sz
/// \return Builder object itself
Builder &SetCacheMemSz(uint64_t cache_mem_sz) {
cache_mem_sz_ = cache_mem_sz;
return *this;
}

/// Setter function to spill attribute
/// \param spill
/// Builder object itself
Builder &SetSpill(bool spill) {
spill_ = spill;
return *this;
}

/// Setter function to set rpc hostname
/// \param host
/// \return Builder object itself
Builder &SetHostname(std::string host) {
hostname_ = std::move(host);
return *this;
}

/// Setter function to set tcpip port
/// \param port
/// \return Builder object itself.
Builder &SetPort(int32_t port) {
port_ = port;
return *this;
}

/// Setter function to set number of async rpc workers
/// \param num_workers
/// \return Builder object itself
Builder &SetNumWorkers(int32_t num_workers) {
num_workers_ = num_workers;
return *this;
}

/// Setter function to set prefetch amount for fetching rows from cache server
/// \param prefetch_sz
/// \return Builder object itself
Builder &SetPrefetchSize(int32_t prefetch_sz) {
prefetch_size_ = prefetch_sz;
return *this;
}

/// Getter functions
session_id_type getSessionId() const { return session_id_; }
uint64_t getCacheMemSz() const { return cache_mem_sz_; }
bool isSpill() const { return spill_; }
const std::string &getHostname() const { return hostname_; }
int32_t getPort() const { return port_; }
int32_t getNumWorkers() const { return num_workers_; }
int32_t getPrefetchSize() const { return prefetch_size_; }

Status SanityCheck() {
CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited");
CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
return Status::OK();
}

Status Build(std::shared_ptr<CacheClient> *out) {
RETURN_UNEXPECTED_IF_NULL(out);
RETURN_IF_NOT_OK(SanityCheck());
*out = std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_,
prefetch_size_);
return Status::OK();
}

private:
session_id_type session_id_;
uint64_t cache_mem_sz_;
bool spill_;
std::string hostname_;
int32_t port_;
int32_t num_workers_;
int32_t prefetch_size_;
};

/// \brief Constructor
/// \param session_id A user assigned session id for the current pipeline
/// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
/// \param spill Spill to disk if out of memory
CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill);
CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, int32_t port,
int32_t num_workers, int32_t prefetch_size);

/// \brief Destructor
~CacheClient() = default;

/// \brief Getter function for returning the current session id
/// \return session id
uint64_t session_id() const { return session_id_; }
~CacheClient() { (void)comm_->ServiceStop(); }

/// \brief Send a TensorRow to the cache server
/// \param[in] row
@@ -83,14 +189,7 @@ class CacheClient {
/// \brief Get the statistics from a cache.
/// \param[in/out] Pointer to a pre-allocated ServiceStat object
/// \return Status object
struct ServiceStat {
int64_t num_mem_cached;
int64_t num_disk_cached;
row_id_type min_row_id;
row_id_type max_row_id;
int8_t cache_service_state;
};
Status GetStat(ServiceStat *);
Status GetStat(CacheServiceStat *);

/// \brief Cache the schema at the cache server
/// \param map The unordered map of the schema
@@ -122,18 +221,45 @@ class CacheClient {
/// \return Cookie
std::string cookie() const { return cookie_; }

/// \brief Send a request async to the server
/// \param rq BaseRequest
/// \return Status object
Status PushRequest(std::shared_ptr<BaseRequest> rq) const;

/// \brief If the remote server supports local bypass using shared memory
/// \return boolean value
bool SupportLocalClient() const { return local_bypass_; }

/// \brief Return the base memory address if we attach to any shared memory.
auto SharedMemoryBaseAddr() const { return comm_->SharedMemoryBaseAddr(); }

/// Getter functions
session_id_type session_id() const { return cinfo_.session_id(); }
uint64_t getCacheMemSz() const { return cache_mem_sz_; }
bool isSpill() const { return spill_; }
const std::string &getHostname() const { return hostname_; }
int32_t getPort() const { return port_; }
int32_t getNumWorkers() const { return num_workers_; }
int32_t getPrefetchSize() const { return prefetch_size_; }

private:
mutable RWLock mux_;
uint64_t cache_mem_sz_;
bool spill_;
// The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow
// sharing of the cache.
uint32_t session_id_;
uint32_t cache_crc_;
CacheClientInfo cinfo_;
// The server_connection_id_ is the actual id we use for operations after the cache is built
connection_id_type server_connection_id_;
// Some magic cookie returned from the cache server.
std::string cookie_;
// Comm layer
bool local_bypass_;
std::string hostname_;
int32_t port_;
int32_t num_workers_;
int32_t prefetch_size_;
mutable std::shared_ptr<CacheClientGreeter> comm_;
};
} // namespace dataset
} // namespace mindspore


+ 90
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h View File

@@ -0,0 +1,90 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_

/// \note This header file contains common header files and some inlines used by
/// both client and server side codes. Do not put code that is not common here.
/// There are client and server specific header files.

// On platform like Windows, we may support only tcp/ip clients
#if !defined(_WIN32) && !defined(_WIN64)
#define CACHE_LOCAL_CLIENT 1
#endif

#ifdef CACHE_LOCAL_CLIENT
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#else
typedef int key_t;
#endif
#ifdef ENABLE_CACHE
#include <grpcpp/grpcpp.h>
#endif
#include <string>
#ifdef ENABLE_CACHE
#include "proto/cache_grpc.grpc.pb.h"
#endif
#include "proto/cache_grpc.pb.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
namespace mindspore {
namespace dataset {
/// \brief CacheRow and BatchFetch requests will switch to use shared memory method (if supported
/// on the platform) when the amount of bytes sent is greater than the following number.
/// For too small amount, we won't get any benefit using shared memory method because we need
/// two rpc requests to use shared memory method.
constexpr static int32_t kLocalByPassThreshold = 64 * 1024;
/// \brief A flag used by the BatchFetch request (client side) if it can support local bypass
constexpr static uint32_t kLocalClientSupport = 1;
/// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is
/// inline in the protobuf. This also implies kLocalClientSupport is also true.
constexpr static uint32_t kDataIsInSharedMemory = 2;

/// \brief Convert a Status object into a protobuf
/// \param rc[in] Status object
/// \param reply[in/out] pointer to pre-allocated protobuf object
inline void Status2CacheReply(const Status &rc, CacheReply *reply) {
reply->set_rc(static_cast<google::int32>(rc.get_code()));
reply->set_msg(rc.ToString());
}

/// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number
/// \param port
/// \return unix socket url
inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); }

/// \brief Generate a shared memory key using the tcp/ip port.
/// \note It must be called after the cache server generates the unix socket or ftok will fail.
/// \note Caller must check the return value. -1 means ftok failed.
/// \param[in] port
/// \param[out] err. If not null and ftok fails, this will contain the value of errno
/// \return key
inline key_t PortToFtok(int port, int *err) {
key_t shmkey = -1;
#ifdef CACHE_LOCAL_CLIENT
const std::string unix_path = PortToUnixSocketPath(port);
shmkey = ftok(unix_path.data(), 'a');
if (err != nullptr && shmkey == (key_t)-1) {
*err = errno;
}
#endif
return shmkey;
}
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_

+ 151
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_fbb.cc View File

@@ -0,0 +1,151 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_fbb.h"
namespace mindspore {
namespace dataset {
/// A private function used by SerializeTensorRowHeader to serialize each column in a tensor
/// \note Not to be called by outside world
/// \return Status object
Status SerializeOneTensorMeta(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb,
const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off) {
RETURN_UNEXPECTED_IF_NULL(out_off);
const Tensor *ts = ts_ptr.get();
auto shape_off = fbb->CreateVector(ts->shape().AsVector());
const auto ptr = ts->GetBuffer();
if (ptr == nullptr) {
RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
}
auto src = ts->type().value();
TensorType dest;
#define CASE(t) \
case DataType::t: \
dest = TensorType::TensorType_##t; \
break
// Map the type to fill in the flat buffer.
switch (src) {
CASE(DE_BOOL);
CASE(DE_INT8);
CASE(DE_UINT8);
CASE(DE_INT16);
CASE(DE_UINT16);
CASE(DE_INT32);
CASE(DE_UINT32);
CASE(DE_INT64);
CASE(DE_UINT64);
CASE(DE_FLOAT16);
CASE(DE_FLOAT32);
CASE(DE_FLOAT64);
CASE(DE_STRING);
default:
MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
RETURN_STATUS_UNEXPECTED("Unknown type");
}
#undef CASE

TensorMetaMsgBuilder ts_builder(*fbb);
ts_builder.add_dims(shape_off);
ts_builder.add_type(dest);
auto ts_off = ts_builder.Finish();
*out_off = ts_off;
return Status::OK();
}

Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *out_fbb) {
RETURN_UNEXPECTED_IF_NULL(out_fbb);
auto fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
try {
fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
std::vector<int64_t> tensor_sz;
v.reserve(row.size());
tensor_sz.reserve(row.size());
// We will go through each column in the row.
for (const std::shared_ptr<Tensor> &ts_ptr : row) {
flatbuffers::Offset<TensorMetaMsg> ts_off;
RETURN_IF_NOT_OK(SerializeOneTensorMeta(fbb, ts_ptr, &ts_off));
v.push_back(ts_off);
tensor_sz.push_back(ts_ptr->SizeInBytes());
}
auto column_off = fbb->CreateVector(v);
auto data_sz_off = fbb->CreateVector(tensor_sz);
TensorRowHeaderMsgBuilder row_builder(*fbb);
row_builder.add_column(column_off);
row_builder.add_data_sz(data_sz_off);
// Pass the row_id even if it may not be known.
row_builder.add_row_id(row.getId());
row_builder.add_size_of_this(-1); // fill in later after we call Finish.
auto out = row_builder.Finish();
fbb->Finish(out);
// Now go back to fill in size_of_this in the flat buffer.
auto msg = GetMutableTensorRowHeaderMsg(fbb->GetBufferPointer());
auto success = msg->mutate_size_of_this(fbb->GetSize());
if (!success) {
RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
}
(*out_fbb) = std::move(fbb);
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}

Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out) {
RETURN_UNEXPECTED_IF_NULL(col_ts);
auto shape_in = col_ts->dims();
auto type_in = col_ts->type();
std::vector<dsize_t> v;
v.reserve(shape_in->size());
v.assign(shape_in->begin(), shape_in->end());
TensorShape shape(v);
DataType::Type dest = DataType::DE_UNKNOWN;
#define CASE(t) \
case TensorType_##t: \
dest = DataType::Type::t; \
break

switch (type_in) {
CASE(DE_BOOL);
CASE(DE_INT8);
CASE(DE_UINT8);
CASE(DE_INT16);
CASE(DE_UINT16);
CASE(DE_INT32);
CASE(DE_UINT32);
CASE(DE_INT64);
CASE(DE_UINT64);
CASE(DE_FLOAT16);
CASE(DE_FLOAT32);
CASE(DE_FLOAT64);
CASE(DE_STRING);
}
#undef CASE

DataType type(dest);
std::shared_ptr<Tensor> ts;
RETURN_IF_NOT_OK(
Tensor::CreateFromMemory(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize(), &ts));
// Next we restore the real data which can be embedded or stored separately.
if (ts->SizeInBytes() != data.GetSize()) {
MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
<< "Dumping tensor\n"
<< *ts << "\n";
RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
}
*out = std::move(ts);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 46
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_fbb.h View File

@@ -0,0 +1,46 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_

/// This header contains some serialize and deserialize functions for tensor row using
/// Google Flatbuffer

#include <memory>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/util/slice.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
/// \brief Function to serialize TensorRow header used by CacheRowRequest
/// \param row TensorRow
/// \param fbb [in/out] fbb that contains the serialized data
/// \return Status object
Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *fbb);

/// \brief A function used by BatchFetchRequest to deserialize a flat buffer back to a tensor row.
/// \param col_ts A serialized version of Tensor meta data
/// \param data Tensor data wrapped in a slice
/// \param out Tensor
/// \return Status object
Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_

+ 54
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc.proto View File

@@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

syntax = "proto3";
package mindspore.dataset;
option cc_enable_arenas = true;

// The session_id and crc work together to uniquely identify this particular cache and allow
// sharing of the cache.
message CacheClientInfo {
uint32 session_id = 1;
uint32 crc = 2;
}

message CacheRequest {
// Type of rpc request
int32 type = 1;
// Extra optional flag used by individual request if needed
uint32 flag = 2;
oneof connect_info {
// The server_connection_id is the actual id we use for operations after the cache is built
int64 connection_id = 3;
// But some request like CreateCache we have to use the session id and crc to connect to the server.
CacheClientInfo connection_info = 4;
}
// Everything else is just vector of buffers
repeated bytes buf_data = 5;
}

message CacheReply {
int32 rc = 1;
string msg = 2;
// Extra optional flag used by individual request if needed
uint32 flag = 3;
// What the server send back is a plain buffer
bytes result = 4;
}

service CacheServerGreeter {
rpc CacheServerRequest (CacheRequest) returns (CacheReply) {}
}

+ 161
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.cc View File

@@ -0,0 +1,161 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_grpc_client.h"
#include <chrono>
namespace mindspore {
namespace dataset {
Status CacheClientRequestTag::MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
std::unique_ptr<CacheClientRequestTag> &&tag) {
// If there is anything extra we need to do before we send.
RETURN_IF_NOT_OK(tag->base_rq_->Prepare());
// One minute timeout
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
tag->ctx_.set_deadline(deadline);
tag->rpc_ = stub->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, cq);
tag->rpc_->StartCall();
// Last step is we release the ownership and transfer it to the completion queue.
// The memory will be released by WorkerEntry or by the destructor when we drain the queue
auto ccReqTag = tag.release();
ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_,
ccReqTag); // inject this object into the completion queue
return Status::OK();
}

CacheClientGreeter::~CacheClientGreeter() {
(void)ServiceStop();
// Detach from shared memory if any
if (shmat_addr_ != nullptr) {
shmdt(shmat_addr_);
shmat_addr_ = nullptr;
}
}

CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers)
: num_workers_(num_workers), shm_key_(-1), shm_id_(-1), shmat_addr_(nullptr) {
grpc::ChannelArguments args;
// We need to bump up the message size to unlimited. The default receiving
// message limit is 4MB which is not big enough.
args.SetMaxReceiveMessageSize(-1);
#if CACHE_LOCAL_CLIENT
// Try connect locally to the unix_socket first as the first preference
// Need to resolve hostname to ip address rather than to do a string compare
if (hostname == "127.0.0.1") {
std::string target = "unix://" + PortToUnixSocketPath(port);
channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
} else {
#endif
std::string target = hostname + ":" + std::to_string(port);
channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
#if CACHE_LOCAL_CLIENT
}
#endif
stub_ = CacheServerGreeter::NewStub(channel_);
}

Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass) {
*local_bypass = false;
#if CACHE_LOCAL_CLIENT
int err;
shm_key_ = PortToFtok(port, &err);
if (shm_key_ == (key_t)-1) {
std::string errMsg = "Ftok failed with errno " + std::to_string(err);
RETURN_STATUS_UNEXPECTED(errMsg);
}
// Attach to the shared memory
shm_id_ = shmget(shm_key_, 0, 0);
if (shm_id_ == -1) {
RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno));
}
shmat_addr_ = shmat(shm_id_, nullptr, 0);
if (shmat_addr_ == reinterpret_cast<void *>(-1)) {
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
}
*local_bypass = true;
#endif
return Status::OK();
}

Status CacheClientGreeter::DoServiceStart() {
RETURN_IF_NOT_OK(vg_.ServiceStart());
RETURN_IF_NOT_OK(DispatchWorkers(num_workers_));
return Status::OK();
}

Status CacheClientGreeter::DoServiceStop() {
// Shutdown the queue. We don't accept any more new incomers.
cq_.Shutdown();
// Shutdown the TaskGroup.
vg_.interrupt_all();
vg_.join_all(Task::WaitFlag::kNonBlocking);
// Drain the queue
bool success;
void *tag;
while (cq_.Next(&tag, &success)) {
auto r = reinterpret_cast<CacheClientRequestTag *>(tag);
delete r;
}
return Status::OK();
}

Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) {
auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq));
return tag->MakeCall(stub_.get(), &cq_, std::move(tag));
}

Status CacheClientGreeter::WorkerEntry() {
TaskManager::FindMe()->Post();
do {
bool success;
void *tag;
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
// Set a timeout for one second. Check for interrupt if we need to do early exit.
auto r = cq_.AsyncNext(&tag, &success, deadline);
if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) {
auto rq = reinterpret_cast<CacheClientRequestTag *>(tag);
if (success) {
auto &rc = rq->rc_;
if (!rc.ok()) {
auto error_code = rq->rc_.error_code();
std::string errMsg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code);
Status remote_rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
Status2CacheReply(remote_rc, &rq->base_rq_->reply_);
}
// Notify the waiting thread.
rq->Notify();
}
// We can now free the memory
delete rq;
} else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) {
// If we are interrupted, exit. Otherwise wait again.
RETURN_IF_INTERRUPTED();
} else {
// Queue is drained.
break;
}
} while (true);
return Status::OK();
}

Status CacheClientGreeter::DispatchWorkers(int32_t num_workers) {
auto f = std::bind(&CacheClientGreeter::WorkerEntry, this);
for (auto i = 0; i < num_workers; ++i) {
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Async reply", f));
}
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 102
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.h View File

@@ -0,0 +1,102 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_

#include <memory>
#include <string>
#include <utility>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/util/service.h"
#include "minddata/dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
/// \brief A client view of gRPC request
/// Like the class CacheServerRequest, this is used as a tag to inject into the gRPC
/// completion queue. The thread that makes the rpc request will wait on a wait post
/// area for the reply to come back. Since this tag will be deleted from memory and
/// we thus we need to work on a shared pointer of the BaseRequest such that its
/// use count is at least two. Otherwise either thread will be referencing stale memory.
/// \see CacheServerRequest
class CacheClientRequestTag {
public:
friend class CacheClientGreeter;
explicit CacheClientRequestTag(std::shared_ptr<BaseRequest> rq) : base_rq_(std::move(rq)) {}
~CacheClientRequestTag() = default;

/// \brief Make a RPC call
/// \param stub from CacheClientGreeter
/// \param cq from CacheClientGreeter
/// \return Status object
static Status MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
std::unique_ptr<CacheClientRequestTag> &&tag);

/// \brief Notify the client that a result has come back from the server
void Notify() { base_rq_->wp_.Set(); }

private:
std::shared_ptr<BaseRequest> base_rq_;
grpc::Status rc_;
grpc::ClientContext ctx_;
std::unique_ptr<grpc::ClientAsyncResponseReader<CacheReply>> rpc_;
};

/// \brief A GRPC layer to convert BaseRequest into protobuf and send to the cache server using gRPC
/// \see BaseRequest
class CacheClientGreeter : public Service {
friend class CacheClient;

public:
explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers);
~CacheClientGreeter();

/// Override base Service class
Status DoServiceStart() override;
Status DoServiceStop() override;

/// \brief Send the request to the server
/// \return Status object
Status HandleRequest(std::shared_ptr<BaseRequest> rq);

/// \brief A handful of threads will be handling async reply from the server
/// \return
Status WorkerEntry();

/// \brief Kick off threads to receive reply from the server
Status DispatchWorkers(int32_t num_workers);

/// \brief Attach to shared memory for local client
/// \note Called after we have established a connection.
/// \return Status object.
Status AttachToSharedMemory(int32_t port, bool *local_bypass);

/// \brief This returns where we attach to the shared memory.
/// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return shmat_addr_; }

private:
std::shared_ptr<grpc::Channel> channel_;
std::unique_ptr<CacheServerGreeter::Stub> stub_;
grpc::CompletionQueue cq_;
TaskGroup vg_;
int32_t num_workers_;
key_t shm_key_;
int32_t shm_id_;
void *shmat_addr_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_

+ 203
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc View File

@@ -0,0 +1,203 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_grpc_server.h"
#include <limits>
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/util/path.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb)
: port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb) {
// Setup a path for unix socket.
unix_socket_ = PortToUnixSocketPath(port);
// We can't generate the ftok key yet until the unix_socket_ is created
}

void CacheServerGreeterImpl::Shutdown() {
if (server_) {
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
server_->Shutdown(deadline);
}
// Always shutdown the completion queue after the server.
if (cq_) {
cq_->Shutdown();
// We need to drain the queue. All the tag is coming from
// the Services pool which will be shutdown as well. So we
// ignore the tag.
void *tag;
bool success;
while (cq_->Next(&tag, &success)) {
}
}
}

CacheServerGreeterImpl::~CacheServerGreeterImpl() { Shutdown(); }

Status CacheServerGreeterImpl::IpcResourceCleanup() {
#if CACHE_LOCAL_CLIENT
int err;
auto shm_key = PortToFtok(port_, &err);
// We are expecting the unix path doesn't exist.
if (shm_key == (key_t)-1) {
return Status::OK();
}
// Attach to the shared memory
auto shm_id = shmget(shm_key, 0, 0);
if (shm_id == -1) {
return Status::OK();
}
struct shmid_ds ds {};
auto inx = shmctl(shm_id, IPC_STAT, &ds);
if (inx == -1) {
std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id);
errMsg += "\nPlesae remove it manually using ipcrm -m command";
RETURN_STATUS_UNEXPECTED(errMsg);
}
if (ds.shm_nattch == 0) {
// Stale shared memory from last time.
// Remove both the memory and the socket path
inx = shmctl(shm_id, IPC_RMID, nullptr);
if (inx == -1) {
std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id);
errMsg += ". Errno :" + std::to_string(errno);
errMsg += "\nPlesae remove it manually using ipcrm -m command";
RETURN_STATUS_UNEXPECTED(errMsg);
}
Path p(unix_socket_);
(void)p.Remove();
} else {
// Server is already up.
MS_LOG(ERROR) << "Cache server is already up and running";
// We return a duplicate error. The main() will intercept
// and output a proper message
return Status(StatusCode::kDuplicateKey);
}
#endif
return Status::OK();
}

Status CacheServerGreeterImpl::Run() {
// To listen on all interfaces, use 0.0.0.0
// Use 127.0.0.1 if just locally on the same machine.
std::string host("0.0.0.0"); // listen on all interfaces.
std::string server_address = host + ":" + std::to_string(port_);
grpc::ServerBuilder builder;
// Default message size for gRPC is 4MB. Increase it to 2g-1
builder.SetMaxReceiveMessageSize(std::numeric_limits<int32_t>::max());
int port_tcpip = 0;
#if CACHE_LOCAL_CLIENT
int port_local = 0;
// Check if we need to do clean up on the shared memory if the server
// came down unexpectedly like SEGV
RETURN_IF_NOT_OK(IpcResourceCleanup());
// We also optimize on local clients on the same machine using unix socket
builder.AddListeningPort("unix://" + unix_socket_, grpc::InsecureServerCredentials(), &port_local);
#endif
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_tcpip);
builder.RegisterService(&svc_);
cq_ = builder.AddCompletionQueue();
server_ = builder.BuildAndStart();
if (server_) {
MS_LOG(INFO) << "Server listening on " << server_address;
#if CACHE_LOCAL_CLIENT
RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_));
MS_LOG(INFO) << "Creation of local socket and shared memory successful";
#endif
} else {
std::string errMsg = "Fail to start server. ";
if (port_tcpip != port_) {
errMsg += "Unable to bind to tcpip port " + std::to_string(port_) + ".";
}
#if CACHE_LOCAL_CLIENT
if (port_local == 0) {
errMsg += " Unable to create unix socket " + unix_socket_ + ".";
}
#endif
RETURN_STATUS_UNEXPECTED(errMsg);
}
return Status::OK();
}

Status CacheServerGreeterImpl::HandleRequest(int32_t worker_id) {
bool success;
void *tag;
// We loop through the grpc queue. Each connection if successful
// will come back with our own tag which is an instance of CacheServerRequest
// and we simply call its functor. But first we need to create these instances
// and inject them into the grpc queue.
CacheServerRequest *p;
// Get a free tag from my free list.
RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(worker_id, &p));
RETURN_IF_NOT_OK((*p)(&svc_, cq_.get()));
do {
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
// Set a timeout for one second. Check for interrupt if we need to do early exit.
auto r = cq_->AsyncNext(&tag, &success, deadline);
if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) {
if (success) {
auto rq = static_cast<CacheServerRequest *>(tag);
RETURN_IF_NOT_OK((*rq)(&svc_, cq_.get()));
}
} else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) {
// If we are interrupted, exit. Otherwise wait again.
RETURN_IF_INTERRUPTED();
} else {
// Queue is drained.
break;
}
} while (true);
return Status::OK();
}

Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grpc::ServerCompletionQueue *cq) {
auto myQID = getQid();
if (st_ == STATE::CREATE) {
st_ = STATE::PROCESS;
svc->RequestCacheServerRequest(&ctx_, &rq_, &responder_, cq, cq, this);
} else if (st_ == STATE::PROCESS) {
// Get a new tag and handle the next request before we serve the current request.
// The tag will be recycled when its state is changed to FINISH
CacheServerRequest *next_rq;
RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(myQID, &next_rq));
RETURN_IF_NOT_OK((*next_rq)(svc, cq));
// Now we continue with the current request.
// First thing we need to extract the type from the incoming request.
// When this object was first created (i.e. STATE::CREATE), we set the type to UNKNOWN.
type_ = static_cast<RequestType>(rq_.type());
// Now we pass the address of this instance to CacheServer's main loop.
MS_LOG(DEBUG) << "Handle request " << *this;
auto &cs = CacheServer::GetInstance();
RETURN_IF_NOT_OK(cs.PushRequest(myQID, this));
} else if (st_ == STATE::FINISH) {
MS_LOG(DEBUG) << *this << " Finished.";
// Return back to the free list.
RETURN_IF_NOT_OK(CacheServer::ReturnRequestTag(this));
}
return Status::OK();
}

void CacheServerRequest::Print(std::ostream &out) const {
if (rq_.has_connection_info()) {
out << "Session Id: " << rq_.connection_info().session_id() << " CRC: " << rq_.connection_info().crc();
} else {
out << "Connection Id: " << rq_.connection_id();
}
out << " ";
BaseRequest::Print(out);
}
} // namespace dataset
} // namespace mindspore

+ 103
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h View File

@@ -0,0 +1,103 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task_manager.h"

namespace mindspore {
namespace dataset {
/// \brief Server side view of BaseRequest. Incoming request are in the form of protobuf objects
/// and this class is used to translate from protobuf to structures understood by CacheService class.
/// \see CacheService
class CacheServerRequest : public BaseRequest {
public:
friend class CacheServer;
enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
explicit CacheServerRequest(int32_t queue_id)
: BaseRequest::BaseRequest(BaseRequest::RequestType::kRequestUnknown),
qid_(queue_id),
st_(STATE::CREATE),
responder_(&ctx_) {}

~CacheServerRequest() = default;

/// \brief Functor. Used mainly by CacheServerGreeterImpl class to tag each incoming request and this
/// functor will translate each protobuf into some form understood by by CacheService class.
/// \param svc Async service
/// \param cq Completion queue
/// \return Status object
Status operator()(CacheServerGreeter::AsyncService *svc, grpc::ServerCompletionQueue *cq);

/// \brief Override the base class Print method
/// \param out
void Print(std::ostream &out) const override;

/// \brief Getter of the queue id
/// \return The queue where the request should go to
int32_t getQid() const { return qid_; }

private:
int32_t qid_;
Status rc_;
STATE st_;
grpc::ServerContext ctx_;
grpc::ServerAsyncResponseWriter<CacheReply> responder_;
};

/// \brief Implementation of CacheServerGreeter
/// \note It is an async server
/// \see cache_grpc.proto
class CacheServerGreeterImpl final {
friend class CacheServer;

public:
explicit CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb);
virtual ~CacheServerGreeterImpl();
/// \brief Brings up gRPC server
/// \return none
Status Run();
/// \brief Entry function to handle cache server request
Status HandleRequest(int32_t worker_id);

/// Return the shared memory pool.
/// \return Return the shared memory pool
CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); }

void Shutdown();

Status IpcResourceCleanup();

private:
int32_t port_;
size_t shm_pool_sz_in_gb_;
std::string unix_socket_;
CacheServerGreeter::AsyncService svc_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::unique_ptr<grpc::Server> server_;
std::unique_ptr<CachedSharedMemoryArena> shm_pool_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_

+ 121
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc View File

@@ -0,0 +1,121 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/cache/cache_server.h"
#include <sys/types.h>
#include <unistd.h>
#ifdef USE_GLOG
#include <glog/logging.h>
#endif
#include <cstdlib>

namespace ds = mindspore::dataset;

int main(int argc, char **argv) {
ds::Status rc;
ds::CacheServer::Builder builder;

// This executable is not to be called directly, and should be invoked by cache_admin executable.
if (argc != 7) {
rc = ds::Status(ds::StatusCode::kSyntaxError);
std::cerr << rc.ToString() << std::endl;
return static_cast<int>(rc.get_code());
}

builder.SetRootDirectory(argv[1])
.SetNumWorkers(strtol(argv[2], nullptr, 10))
.SetPort(strtol(argv[3], nullptr, 10))
.SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10));

#ifdef USE_GLOG
FLAGS_minloglevel = strtol(argv[5], nullptr, 10);
#endif

auto daemonize_string = argv[6];
bool daemonize = strcmp(daemonize_string, "true") == 0 || strcmp(daemonize_string, "TRUE") == 0 ||
strcmp(daemonize_string, "t") == 0 || strcmp(daemonize_string, "T") == 0;

// We always change directory to / on unix rather than using the directory where the cache_server
// is called. This is a standard procedure for daemonize a process on unix.
if (chdir("/") == -1) {
std::string errMsg = "Unable to change directory to /. Errno = " + std::to_string(errno);
std::cerr << errMsg << std::endl;
return -1;
}

// Simple check of the parameters before we move on.
rc = builder.SanityCheck();
if (rc.IsError()) {
std::cerr << rc.ToString() << std::endl;
return static_cast<int>(rc.get_code());
}

#ifdef USE_GLOG
FLAGS_log_dir = "/tmp";
google::InitGoogleLogging(argv[0]);
#endif

if (daemonize) {
// fork the child process to become the daemon
pid_t pid = fork();
// failed to fork
if (pid < 0) {
std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno);
std::cerr << err_msg << std::endl;
return errno;
} else if (pid > 0) {
// Parent
std::cerr << "cache server daemon process has been created as process id: " << pid
<< "\nCheck log file for any start up error" << std::endl;
signal(SIGCHLD, SIG_IGN); // ignore sig child signal.
return 0;
} else {
// Child process will continue from here if daemonize and parent has already exited.
// If we are running in the foreground, none of the code in block below will be run.
pid_t sid;
umask(0);
sid = setsid();
if (sid < 0) {
MS_LOG(ERROR) << "Failed to setsid(). Errno = " << std::to_string(errno);
return errno;
}
close(0);
close(1);
close(2);
}
}

// Dump the summary
MS_LOG(INFO) << builder << std::endl;
rc = builder.Build();
if (rc.IsOk()) {
ds::CacheServer &cs = ds::CacheServer::GetInstance();
// Kick off the threads. Loop forever and never return unless error.
rc = cs.Run();
if (rc.get_code() == ds::StatusCode::kDuplicateKey) {
std::string errMsg = "Server is already started";
MS_LOG(ERROR) << errMsg;
std::cerr << errMsg << std::endl;
return 0;
}
}
if (rc.IsError()) {
MS_LOG(ERROR) << rc.ToString();
std::cerr << rc.ToString() << std::endl;
return static_cast<int>(rc.get_code());
}
return 0;
}

+ 175
- 145
mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc View File

@@ -14,154 +14,149 @@
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_request.h"

#include <cstdlib>
#include <thread>
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_fbb.h"
namespace mindspore {
namespace dataset {

Status CacheRowRequest::SerializeCacheRowRequest(const TensorRow &row) {
buffers_.reserve(row.size() + 1);
RETURN_IF_NOT_OK(SerializeTensorRowHeader(row));
buffers_.push_back(fbb_->GetBufferPointer());
for (const auto &ts : row) {
buffers_.push_back(ts->GetBuffer());
}
Status BaseRequest::Wait() {
RETURN_IF_NOT_OK(wp_.Wait());
Status remote_rc(static_cast<StatusCode>(reply_.rc()), reply_.msg());
RETURN_IF_NOT_OK(remote_rc);
// Any extra work to do before we return back to the client.
RETURN_IF_NOT_OK(PostReply());
return Status::OK();
}

Status CacheRowRequest::SerializeTensorRowHeader(const TensorRow &row) {
try {
fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
std::vector<int64_t> tensor_sz;
v.reserve(row.size());
tensor_sz.reserve(row.size());
// We will go through each column in the row.
for (const std::shared_ptr<Tensor> &ts_ptr : row) {
flatbuffers::Offset<TensorMetaMsg> ts_off;
RETURN_IF_NOT_OK(SerializeOneTensorMeta(ts_ptr, &ts_off));
v.push_back(ts_off);
tensor_sz.push_back(ts_ptr->SizeInBytes());
Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row) {
CHECK_FAIL_RETURN_UNEXPECTED(row.size() > 0, "Empty tensor row");
CHECK_FAIL_RETURN_UNEXPECTED(cc->SupportLocalClient() == support_local_bypass_, "Local bypass mismatch");
// Calculate how many bytes (not counting the cookie) we are sending to the server. We only
// use shared memory (if supported) if we exceed certain amount
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
RETURN_IF_NOT_OK(::mindspore::dataset::SerializeTensorRowHeader(row, &fbb));
sz_ += fbb->GetSize();
for (const auto &ts : row) {
sz_ += ts->SizeInBytes();
}
bool sent_using_local_bypass = support_local_bypass_ ? (sz_ >= kLocalByPassThreshold) : false;
uint32_t flag = 0;
if (support_local_bypass_) {
BitSet(&flag, kLocalClientSupport);
}
if (sent_using_local_bypass) {
BitSet(&flag, kDataIsInSharedMemory);
}
rq_.set_flag(flag);
if (sent_using_local_bypass) {
MS_LOG(DEBUG) << "Requesting " << sz_ << " bytes of shared memory data";
// Allocate shared memory from the server
auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), sz_);
RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
RETURN_IF_NOT_OK(mem_rq->Wait());
addr_ = mem_rq->GetAddr();
// Now we need to add that to the base address of where we attach.
auto base = cc->SharedMemoryBaseAddr();
auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr_);
// Now we copy the data onto shared memory.
WritableSlice all(p, sz_);
auto offset = fbb->GetSize();
ReadableSlice header(fbb->GetBufferPointer(), fbb->GetSize());
Status copy_rc;
copy_rc = WritableSlice::Copy(&all, header);
if (copy_rc.IsOk()) {
for (const auto &ts : row) {
WritableSlice row_data(all, offset, ts->SizeInBytes());
ReadableSlice src(ts->GetBuffer(), ts->SizeInBytes());
copy_rc = WritableSlice::Copy(&row_data, src);
if (copy_rc.IsError()) {
break;
}
offset += ts->SizeInBytes();
}
// Fill in where to find the data
AddDataLocation();
}
auto column_off = fbb_->CreateVector(v);
auto data_sz_off = fbb_->CreateVector(tensor_sz);
TensorRowHeaderMsgBuilder row_builder(*fbb_);
row_builder.add_column(column_off);
row_builder.add_data_sz(data_sz_off);
// Pass the row_id even if it may not be known.
row_builder.add_row_id(row.getId());
row_builder.add_size_of_this(-1); // fill in later after we call Finish.
auto out = row_builder.Finish();
fbb_->Finish(out);
// Now go back to fill in size_of_this in the flat buffer.
auto msg = GetMutableTensorRowHeaderMsg(fbb_->GetBufferPointer());
auto success = msg->mutate_size_of_this(fbb_->GetSize());
if (!success) {
RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
if (copy_rc.IsError()) {
// We need to return the memory back to the server
auto mfree_req = GenerateFreeBlockRequest();
Status rc = cc->PushRequest(mfree_req);
// But we won't wait for the result for the sake of performance.
if (rc.IsError()) {
MS_LOG(ERROR) << "Push request for free memory failed.";
}
return copy_rc;
}
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
} else {
// We have already filled the first buffer which is the cookie.
sz_ += rq_.buf_data(0).size();
rq_.add_buf_data(fbb->GetBufferPointer(), fbb->GetSize());
for (const auto &ts : row) {
rq_.add_buf_data(ts->GetBuffer(), ts->SizeInBytes());
}
MS_LOG(DEBUG) << "Sending " << sz_ << " bytes of tensor data in " << rq_.buf_data_size() << " segments";
}
return Status::OK();
}

Status CacheRowRequest::SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr,
flatbuffers::Offset<TensorMetaMsg> *out_off) {
RETURN_UNEXPECTED_IF_NULL(out_off);
const Tensor *ts = ts_ptr.get();
auto shape_off = fbb_->CreateVector(ts->shape().AsVector());
const auto ptr = ts->GetBuffer();
if (ptr == nullptr) {
RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
}
auto src = ts->type().value();
TensorType dest;
#define CASE(t) \
case DataType::t: \
dest = TensorType::TensorType_##t; \
break
// Map the type to fill in the flat buffer.
switch (src) {
CASE(DE_BOOL);
CASE(DE_INT8);
CASE(DE_UINT8);
CASE(DE_INT16);
CASE(DE_UINT16);
CASE(DE_INT32);
CASE(DE_UINT32);
CASE(DE_INT64);
CASE(DE_UINT64);
CASE(DE_FLOAT16);
CASE(DE_FLOAT32);
CASE(DE_FLOAT64);
CASE(DE_STRING);
default:
MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
RETURN_STATUS_UNEXPECTED("Unknown type");
Status CacheRowRequest::PostReply() {
if (!reply_.result().empty()) {
row_id_from_server_ = strtoll(reply_.result().data(), nullptr, 10);
}
#undef CASE

TensorMetaMsgBuilder ts_builder(*fbb_);
ts_builder.add_dims(shape_off);
ts_builder.add_type(dest);
auto ts_off = ts_builder.Finish();
*out_off = ts_off;
return Status::OK();
}

Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data,
std::shared_ptr<Tensor> *out) {
RETURN_UNEXPECTED_IF_NULL(col_ts);
auto shape_in = col_ts->dims();
auto type_in = col_ts->type();
std::vector<dsize_t> v;
v.reserve(shape_in->size());
v.assign(shape_in->begin(), shape_in->end());
TensorShape shape(v);
DataType::Type dest = DataType::DE_UNKNOWN;
#define CASE(t) \
case TensorType_##t: \
dest = DataType::Type::t; \
break

switch (type_in) {
CASE(DE_BOOL);
CASE(DE_INT8);
CASE(DE_UINT8);
CASE(DE_INT16);
CASE(DE_UINT16);
CASE(DE_INT32);
CASE(DE_UINT32);
CASE(DE_INT64);
CASE(DE_UINT64);
CASE(DE_FLOAT16);
CASE(DE_FLOAT32);
CASE(DE_FLOAT64);
CASE(DE_STRING);
Status CacheRowRequest::Prepare() {
if (BitTest(rq_.flag(), kDataIsInSharedMemory)) {
// First one is cookie, followed by address and then size.
CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() == 3, "Incomplete rpc data");
} else {
// First one is cookie. 2nd one is the google flat buffers followed by a number of buffers.
// But we are not going to decode them to verify.
CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() >= 3, "Incomplete rpc data");
}
#undef CASE

DataType type(dest);
std::shared_ptr<Tensor> ts;
RETURN_IF_NOT_OK(
Tensor::CreateFromMemory(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize(), &ts));
// Next we restore the real data which can be embedded or stored separately.
if (ts->SizeInBytes() != data.GetSize()) {
MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
<< "Dumping tensor\n"
<< *ts << "\n";
RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
}
*out = std::move(ts);
return Status::OK();
}

Status BatchFetchRequest::RestoreRows(TensorTable *out) {
BatchFetchRequest::BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id,
bool local_bypass)
: BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(local_bypass), row_id_(row_id) {
rq_.set_connection_id(connection_id);
rq_.set_flag(support_local_bypass_ ? kLocalClientSupport : 0);
// Convert the row id into a flatbuffer
flatbuffers::FlatBufferBuilder fbb;
auto off_t = fbb.CreateVector(row_id);
TensorRowIdsBuilder bld(fbb);
bld.add_row_id(off_t);
auto off = bld.Finish();
fbb.Finish(off);
rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
}

Status BatchFetchRequest::RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr) {
RETURN_UNEXPECTED_IF_NULL(out);
auto num_elements = row_id_.size();
auto *offset_array = reinterpret_cast<const int64_t *>(mem_.GetPointer());
const char *ptr = nullptr;
int64_t sz = 0;
// Tap into the reply flag to see where we can find the data. Server may decide the amount is
// so small that it doesn't use shared memory method.
auto flag = reply_.flag();
bool dataOnSharedMemory = support_local_bypass_ ? (BitTest(flag, kDataIsInSharedMemory)) : false;
if (dataOnSharedMemory) {
auto addr = strtoll(reply_.result().data(), nullptr, 10);
ptr = reinterpret_cast<const char *>(reinterpret_cast<int64_t>(baseAddr) + addr);
RETURN_UNEXPECTED_IF_NULL(out);
*out_addr = addr;
} else {
ptr = reply_.result().data();
*out_addr = -1;
}
auto *offset_array = reinterpret_cast<const int64_t *>(ptr);
sz = offset_array[num_elements];
CHECK_FAIL_RETURN_UNEXPECTED(support_local_bypass_ || sz == reply_.result().length(), "Length mismatch");
TensorTable tbl;
tbl.reserve(num_elements);
ReadableSlice all(mem_.GetPointer(), mem_.GetSizeInBytes());
ReadableSlice all(ptr, sz);
for (auto i = 0; i < num_elements; ++i) {
auto len = offset_array[i + 1] - offset_array[i];
TensorRow row;
@@ -178,10 +173,12 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out) {
auto col_ts = msg->column()->Get(k);
std::shared_ptr<Tensor> ts;
ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k));
RETURN_IF_NOT_OK(RestoreOneTensor(col_ts, data, &ts));
RETURN_IF_NOT_OK(mindspore::dataset::RestoreOneTensor(col_ts, data, &ts));
row.push_back(ts);
ts_offset += data.GetSize();
}
} else {
CHECK_FAIL_RETURN_UNEXPECTED(len == 0, "Data corruption detected.");
}
tbl.push_back(std::move(row));
}
@@ -189,36 +186,69 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out) {
return Status::OK();
}

CreateCacheRequest::CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
CreateCacheRequest::CreateCacheFlag flag)
: BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag) {
// Type has been set already in the base constructor. So we need to fill in the connection info.
// On successful return, we will get the connection id
rq_.mutable_connection_info()->operator=(cinfo);
}

Status CreateCacheRequest::Prepare() {
try {
flatbuffers::FlatBufferBuilder fbb;
CreateCacheRequestMsgBuilder bld(fbb);
bld.add_cache_mem_sz(cache_mem_sz_);
bld.add_flag(static_cast<uint32_t>(flag_));
auto off = bld.Finish();
fbb.Finish(off);
rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}

Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) {
try {
fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
flatbuffers::FlatBufferBuilder fbb;
std::vector<flatbuffers::Offset<ColumnNameMsg>> v;
v.reserve(map.size());
for (auto &column : map) {
auto c = CreateColumnNameMsg(*fbb_, fbb_->CreateString(column.first), column.second);
auto c = CreateColumnNameMsg(fbb, fbb.CreateString(column.first), column.second);
v.push_back(c);
}
auto v_off = fbb_->CreateVector(v);
auto final_off = CreateSchemaMsg(*fbb_, v_off);
fbb_->Finish(final_off);
buf_ = fbb_->GetBufferPointer();
len_of_buf_ = fbb_->GetSize();
auto v_off = fbb.CreateVector(v);
auto final_off = CreateSchemaMsg(fbb, v_off);
fbb.Finish(final_off);
rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}

std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() {
if (column_name_id_map_.empty()) {
auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(mem_.GetPointer());
auto v = map_msg->column();
for (auto i = 0; i < v->size(); ++i) {
auto col = map_msg->column()->Get(i);
column_name_id_map_.emplace(col->name()->str(), col->id());
}
Status FetchSchemaRequest::PostReply() {
auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(reply_.result().data());
auto v = map_msg->column();
for (auto i = 0; i < v->size(); ++i) {
auto col = map_msg->column()->Get(i);
column_name_id_map_.emplace(col->name()->str(), col->id());
}
return column_name_id_map_;
return Status::OK();
}

std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() { return column_name_id_map_; }

Status GetStatRequest::PostReply() {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(reply_.result().data());
stat_.num_disk_cached = msg->num_disk_cached();
stat_.num_mem_cached = msg->num_mem_cached();
stat_.avg_cache_sz = msg->avg_cache_sz();
stat_.max_row_id = msg->max_row_id();
stat_.min_row_id = msg->min_row_id();
stat_.cache_service_state = msg->state();
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 207
- 85
mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h View File

@@ -18,11 +18,16 @@

#include <algorithm>
#include <memory>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#ifdef ENABLE_CACHE
#include "proto/cache_grpc.grpc.pb.h"
#endif
#include "proto/cache_grpc.pb.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
#include "minddata/dataset/util/slice.h"
@@ -30,6 +35,17 @@

namespace mindspore {
namespace dataset {
class CacheClient;
/// \brief Statistic structure for GetStat request
struct CacheServiceStat {
int64_t num_mem_cached;
int64_t num_disk_cached;
int64_t avg_cache_sz;
row_id_type min_row_id;
row_id_type max_row_id;
int8_t cache_service_state;
};

/// \brief CacheClient communicates with CacheServer using Requests.
class BaseRequest {
public:
@@ -44,195 +60,301 @@ class BaseRequest {
kCacheSchema = 6,
kFetchSchema = 7,
kBuildPhaseDone = 8,
kDropSession = 9,
kGenerateSessionId = 10,
kAllocateSharedBlock = 11,
kFreeSharedBlock = 12,
kStopService = 13,
// Add new request before it.
kRequestUnknown = 32767
};
// For kCreateCache
enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L };

friend class CacheServer;
friend class CacheServerRequest;
friend class CacheClientGreeter;
friend class CacheClientRequestTag;

/// \brief Base class of a cache server request
/// \param connection_id A combination of session id and crc that uniquely identifies a connection.
/// \param type Type of the request
explicit BaseRequest(connection_id_type connection_id, RequestType type)
: type_(type), connection_id_(connection_id) {}
explicit BaseRequest(RequestType type) : type_(type) { rq_.set_type(static_cast<google::int32>(type_)); }
virtual ~BaseRequest() = default;
/// \brief Wait for the completion of a request
/// \return Status returned from the cache server
Status Wait() {
RETURN_IF_NOT_OK(wp_.Wait());
return rc_;

/// \brief A print method for debugging
/// \param out The output stream to write output to
virtual void Print(std::ostream &out) const { out << "Request type: " << static_cast<int16_t>(type_); }

/// \brief << Stream output operator overload
/// \param out reference to the output stream
/// \param rq reference to the BaseRequest
/// \return the output stream
friend std::ostream &operator<<(std::ostream &out, const BaseRequest &rq) {
rq.Print(out);
return out;
}

/// \brief Getter function of the current connection id
/// \return Connection id
connection_id_type GetServerConnectionId() const { return connection_id_; }
/// \brief Derived class can implement extra work to be done before the request is sent to the server
virtual Status Prepare() { return Status::OK(); }

/// \brief Derived class can implement extra work to be done after the server sends the request
virtual Status PostReply() { return Status::OK(); }

/// \brief A method for the client to wait for the availability of the result back from the server.
/// \return Status object
Status Wait();

protected:
CacheRequest rq_; // This is what we send to the server
CacheReply reply_; // This is what the server send back

private:
RequestType type_;
connection_id_type connection_id_;
Status rc_;
WaitPost wp_;
WaitPost wp_; // A sync area used by the client side.
};

class FreeSharedBlockRequest : public BaseRequest {
public:
friend class CacheServer;
explicit FreeSharedBlockRequest(connection_id_type connection_id, int64_t addr)
: BaseRequest(RequestType::kFreeSharedBlock) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(std::to_string(addr));
}
~FreeSharedBlockRequest() = default;
};

/// \brief Request to cache a single TensorRow
class CacheRowRequest : public BaseRequest {
public:
friend class CacheServer;
explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie)
: BaseRequest(connection_id, RequestType::kCacheRow), row_id_from_server_(-1), cookie_(cookie) {}
friend class CacheClient;
explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie, bool local_bypass)
: BaseRequest(RequestType::kCacheRow),
support_local_bypass_(local_bypass),
addr_(-1),
sz_(0),
row_id_from_server_(-1) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(cookie);
}
~CacheRowRequest() = default;

/// \brief Serialize a TensorRow for streaming to the cache server
/// \param row TensorRow
/// \return Status object
Status SerializeCacheRowRequest(const TensorRow &row);
Status SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row);

/// \brief Sanity check before we send the row.
/// \return Status object
Status Prepare() override;

/// \brief Override the base function get the row id returned from the server
/// \return Status object
Status PostReply() override;

/// \brief Return the row id assigned to this row for non-mappable dataset
/// \return row id of the cached row
row_id_type GetRowIdAfterCache() { return row_id_from_server_; }

/// \brief If we are doing local bypass, fill in extra request information of where the data is located.
void AddDataLocation() {
if (support_local_bypass_) {
rq_.add_buf_data(std::to_string(addr_));
rq_.add_buf_data(std::to_string(sz_));
}
}

/// \brief If we fail to send the data to the server using shared memory method, we should release
/// the shared memory by sending another request. The following function will generate a suitable
/// request for the CacheClient to send.
std::shared_ptr<FreeSharedBlockRequest> GenerateFreeBlockRequest() {
return std::make_shared<FreeSharedBlockRequest>(rq_.connection_id(), addr_);
}

private:
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_;
bool support_local_bypass_;
int64_t addr_;
int64_t sz_;
row_id_type row_id_from_server_;
std::vector<const void *> buffers_;
std::string cookie_;

/// \brief Private function to serialize one TensorRow
/// \param row TensorRow
/// \return Status object
Status SerializeTensorRowHeader(const TensorRow &row);
/// \brief Private function to serialize one Tensor
/// \param ts_ptr Tensor
/// \return Status object
Status SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off);
};

/// \brief Request to fetch rows in batch
class BatchFetchRequest : public BaseRequest {
public:
friend class CacheServer;
friend class CacheService;
BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id)
: BaseRequest(connection_id, RequestType::kBatchFetchRows), row_id_(row_id) {}
BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, bool local_bypass);
~BatchFetchRequest() = default;
Status RestoreRows(TensorTable *out);
Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr);

private:
bool support_local_bypass_;
std::vector<row_id_type> row_id_;
MemGuard<uint8_t> mem_;
Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out);
};

/// \brief Request to create a cache for the current connection
class CreationCacheRequest : public BaseRequest {
class CreateCacheRequest : public BaseRequest {
public:
friend class CacheServer;
enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L };

/// \brief Constructor
/// \param connection_id
/// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited
/// \param flag Attributes of the cache.
explicit CreationCacheRequest(connection_id_type connection_id, uint64_t cache_mem_sz,
CreateCacheFlag flag = CreateCacheFlag::kNone)
: BaseRequest(connection_id, RequestType::kCreateCache), cache_mem_sz(cache_mem_sz), flag_(flag) {}

~CreationCacheRequest() = default;
explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
CreateCacheFlag flag = CreateCacheFlag::kNone);
~CreateCacheRequest() = default;
void ParseResult(connection_id_type *id, std::string *out) {
auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data());
*id = p->connection_id();
*out = p->cookie()->str();
}

std::string cookie() const { return cookie_; }
/// Overload the base class Prepare
Status Prepare() override;

private:
uint64_t cache_mem_sz;
uint64_t cache_mem_sz_;
CreateCacheFlag flag_;
std::string cookie_;
};

/// \brief Request to purge a cache.
class PurgeCacheRequest : public BaseRequest {
public:
friend class CacheServer;
explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kPurgeCache) {}

explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kPurgeCache) {
rq_.set_connection_id(connection_id);
}
~PurgeCacheRequest() = default;
};

/// \brief Request to destroy a cache
class DestroyCacheRequest : public BaseRequest {
public:
friend class CacheServer;
explicit DestroyCacheRequest(connection_id_type connection_id)
: BaseRequest(connection_id, RequestType::kDestroyCache) {}

/// \brief Destructor
explicit DestroyCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kDestroyCache) {
rq_.set_connection_id(connection_id);
}
~DestroyCacheRequest() = default;
};

/// \brief Obtain the statistics of the current connection
class GetStatRequest : public BaseRequest {
public:
friend class CacheServer;
friend class CacheService;
explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kGetStat) {}
explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(RequestType::kGetStat) {
rq_.set_connection_id(connection_id);
}

~GetStatRequest() = default;

row_id_type GetMinRowId() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->min_row_id();
}
row_id_type GetMaxRowId() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->max_row_id();
}
int64_t GetNumMemCached() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->num_mem_cached();
}
int64_t GetNumDiskCached() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->num_disk_cached();
}
uint8_t GetState() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->state();
/// \brief Override base function to process the result.
Status PostReply() override;

void GetStat(CacheServiceStat *stat) {
if (stat != nullptr) {
(*stat) = stat_;
}
}

private:
MemGuard<uint8_t> mem_;
CacheServiceStat stat_{};
};

/// \brief Request to cache a schema
class CacheSchemaRequest : public BaseRequest {
public:
friend class CacheServer;
explicit CacheSchemaRequest(connection_id_type connection_id)
: BaseRequest(connection_id, RequestType::kCacheSchema), buf_(nullptr), len_of_buf_(0) {}
explicit CacheSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kCacheSchema) {
rq_.set_connection_id(connection_id);
}
~CacheSchemaRequest() = default;

Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map);
const void *GetBuffer() const { return buf_; }

private:
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_;
const void *buf_;
int64_t len_of_buf_;
};

/// \brief Request to fetch a schema
class FetchSchemaRequest : public BaseRequest {
public:
friend class CacheServer;
explicit FetchSchemaRequest(connection_id_type connection_id)
: BaseRequest(connection_id, RequestType::kFetchSchema) {}
explicit FetchSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kFetchSchema) {
rq_.set_connection_id(connection_id);
}
~FetchSchemaRequest() = default;

Status PostReply() override;

std::unordered_map<std::string, int32_t> GetColumnMap();

private:
MemGuard<uint8_t> mem_;
std::unordered_map<std::string, int32_t> column_name_id_map_;
};

/// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only.
class BuildPhaseDoneRequest : public BaseRequest {
public:
friend class CacheServer;
BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie)
: BaseRequest(connection_id, RequestType::kBuildPhaseDone), cookie_(cookie) {}

: BaseRequest(RequestType::kBuildPhaseDone), cookie_(cookie) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(cookie_);
}
~BuildPhaseDoneRequest() = default;

private:
std::string cookie_;
};

/// \brief Request to drop all the caches in the current session
class DropSessionRequest : public BaseRequest {
public:
friend class CacheServer;
explicit DropSessionRequest(const CacheClientInfo &cinfo) : BaseRequest(RequestType::kDropSession) {
rq_.mutable_connection_info()->operator=(cinfo);
}
~DropSessionRequest() = default;
};

class GenerateSessionIdRequest : public BaseRequest {
public:
friend class CacheServer;
GenerateSessionIdRequest() : BaseRequest(RequestType::kGenerateSessionId) {
// We don't have anything client info nor connection id to send. But we will manually
// set the connection id to 0.
rq_.set_connection_id(0);
}

~GenerateSessionIdRequest() = default;

session_id_type GetSessionId() { return atoi(reply_.result().data()); }
};

class AllocateSharedBlockRequest : public BaseRequest {
public:
friend class CacheServer;
explicit AllocateSharedBlockRequest(connection_id_type connection_id, size_t requestedSz)
: BaseRequest(RequestType::kAllocateSharedBlock) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(std::to_string(requestedSz));
}
~AllocateSharedBlockRequest() = default;

/// \brief On return from the server, we get the (relative) address where
/// the free block is located.
/// \return
int64_t GetAddr() {
auto addr = strtoll(reply_.result().data(), nullptr, 10);
return addr;
}
};

class ShutdownRequest : public BaseRequest {
public:
friend class CacheServer;
ShutdownRequest() : BaseRequest(RequestType::kStopService) {}
~ShutdownRequest() = default;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_

+ 550
- 123
mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc View File

@@ -14,25 +14,89 @@
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_server.h"
#include <algorithm>
#include <functional>
#include <limits>
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/util/bit.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/random.h"
#ifdef CACHE_LOCAL_CLIENT
#include "minddata/dataset/util/sig_handler.h"
#endif

namespace mindspore {
namespace dataset {
CacheServer *CacheServer::instance_ = nullptr;
std::once_flag CacheServer::init_instance_flag_;
Status CacheServer::DoServiceStart() {
#ifdef CACHE_LOCAL_CLIENT
// We need to destroy the shared memory if user hits Control-C
RegisterHandlers();
#endif
if (!top_.empty()) {
Path spill(top_);
RETURN_IF_NOT_OK(spill.CreateDirectories());
MS_LOG(INFO) << "CacheServer will use disk folder: " << top_;
}
RETURN_IF_NOT_OK(vg_.ServiceStart());
cache_q_ = std::make_shared<Queue<BaseRequest *>>(1024);
// There will be num_workers_ threads working on the grpc queue and
// the same number of threads working on the CacheServerRequest queue.
// Like a connector object we will set up the same number of queues but
// we do not need to preserve any order. We will set the capacity of
// each queue to be 128 since we are just pushing memory pointers which
// is only 8 byte each.
const int32_t que_capacity = 128;
// This is the request queue from the client
cache_q_ = std::make_shared<QueueList<CacheServerRequest *>>();
cache_q_->Init(num_workers_, que_capacity);
// For the grpc completion queue to work, we need to allocate some
// tags which in our case are instances of CacheServerQuest.
// They got recycled and we will allocate them in advance and push
// them into some free list. We need more (two or three times) the
// size of the cache_q. While each worker is working on a CacheSerRequest,
// we need some extra running injecting in the the qrpc completion queue.
const int32_t multiplier = 3;
const int32_t free_list_capacity = multiplier * (que_capacity + 1);
free_list_ = std::make_shared<QueueList<CacheServerRequest *>>();
free_list_->Init(num_workers_, free_list_capacity);
// We need to have a reference to the services memory pool in case
// the Services goes out of scope earlier than us since it is a singleton
mp_ = Services::GetInstance().GetServiceMemPool();
Allocator<CacheServerRequest> alloc(mp_);
tag_.reserve(num_workers_);
// Now we populate all free list.
for (auto m = 0; m < num_workers_; ++m) {
// Ideally we allocate all the free list in one malloc. But it turns out it exceeds the
// Arena size. So we will we will allocate one segment at a time.
auto my_tag = std::make_unique<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>(alloc);
// Allocate the tag and assign it the current queue
RETURN_IF_NOT_OK(my_tag->allocate(free_list_capacity, m));
for (int i = 0; i < free_list_capacity; ++i) {
RETURN_IF_NOT_OK(free_list_->operator[](m)->Add((*my_tag)[i]));
}
tag_.push_back(std::move(my_tag));
}
RETURN_IF_NOT_OK(cache_q_->Register(&vg_));
auto f = std::bind(&CacheServer::ServerRequest, this);
// Spawn a a few threads to serve the request.
RETURN_IF_NOT_OK(free_list_->Register(&vg_));
// Spawn a few threads to serve the real request.
auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1);
for (auto i = 0; i < num_workers_; ++i) {
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache service worker", std::bind(f, i)));
}
// Start the comm layer
try {
comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_, shared_memory_sz_in_gb_);
RETURN_IF_NOT_OK(comm_layer_->Run());
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
// Finally loop forever to handle the request.
auto r = std::bind(&CacheServer::RpcRequest, this, std::placeholders::_1);
for (auto i = 0; i < num_workers_; ++i) {
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache server", f));
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("rpc worker", std::bind(r, i)));
}
return Status::OK();
}
@@ -65,188 +129,551 @@ CacheService *CacheServer::GetService(connection_id_type id) const {
return nullptr;
}

Status CacheServer::CreateService(connection_id_type connection_id, uint64_t cache_mem_sz,
BaseRequest::CreateCacheFlag flag, std::string *out_cookie) {
Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing connection info");
std::string cookie;
auto session_id = rq->connection_info().session_id();
auto crc = rq->connection_info().crc();
// We concat both numbers to form the internal connection id.
auto connection_id = GetConnectionID(session_id, crc);
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing info to create cache");
auto &create_cache_buf = rq->buf_data(0);
auto p = flatbuffers::GetRoot<CreateCacheRequestMsg>(create_cache_buf.data());
auto flag = static_cast<CreateCacheRequest::CreateCacheFlag>(p->flag());
auto cache_mem_sz = p->cache_mem_sz();
// We can't do spilling unless this server is setup with a spill path in the first place
bool spill = (flag & BaseRequest::CreateCacheFlag::kSpillToDisk) == BaseRequest::CreateCacheFlag::kSpillToDisk;
bool spill =
(flag & CreateCacheRequest::CreateCacheFlag::kSpillToDisk) == CreateCacheRequest::CreateCacheFlag::kSpillToDisk;
bool generate_id =
(flag & BaseRequest::CreateCacheFlag::kGenerateRowId) == BaseRequest::CreateCacheFlag::kGenerateRowId;
(flag & CreateCacheRequest::CreateCacheFlag::kGenerateRowId) == CreateCacheRequest::CreateCacheFlag::kGenerateRowId;
if (spill && top_.empty()) {
RETURN_STATUS_UNEXPECTED("Server is not set up with spill support.");
}
RETURN_UNEXPECTED_IF_NULL(out_cookie);
*out_cookie = "";
flatbuffers::FlatBufferBuilder fbb;
flatbuffers::Offset<flatbuffers::String> off_cookie;
// Before creating the cache, first check if this is a request for a shared usage of an existing cache
// If two CreateService come in with identical connection_id, we need to serialize the create.
// The first create will be successful and be given a special cookie.
UniqueLock lck(&rwLock_);
// Early exit if we are doing global shutdown
if (global_shutdown_) {
return Status::OK();
}
auto end = all_caches_.end();
auto it = all_caches_.find(connection_id);
bool duplicate = false;
if (it == end) {
std::unique_ptr<CacheService> cs;
try {
cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id);
RETURN_IF_NOT_OK(cs->ServiceStart());
*out_cookie = cs->cookie();
cookie = cs->cookie();
all_caches_.emplace(connection_id, std::move(cs));
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
}
} else {
duplicate = true;
MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service";
// We can return OK but we will return a duplicate key so user can act accordingly to either ignore it
// treat it as OK.
return Status(StatusCode::kDuplicateKey);
}
off_cookie = fbb.CreateString(cookie);
CreateCacheReplyMsgBuilder bld(fbb);
bld.add_connection_id(connection_id);
bld.add_cookie(off_cookie);
auto off = bld.Finish();
fbb.Finish(off);
reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
// Track the history of all the sessions that we have created so far.
history_sessions_.insert(session_id);
// We can return OK but we will return a duplicate key so user can act accordingly to either ignore it
// treat it as OK.
return duplicate ? Status(StatusCode::kDuplicateKey) : Status::OK();
}

Status CacheServer::DestroyCache(CacheService *cs, CacheRequest *rq) {
// We need a strong lock to protect the map.
UniqueLock lck(&rwLock_);
// it is already destroyed. Ignore it.
if (cs != nullptr) {
auto id = rq->connection_id();
MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id);
// std::map will invoke the destructor of CacheService. So we don't need to do anything here.
auto n = all_caches_.erase(id);
if (n == 0) {
// It has been destroyed by another duplicate request.
MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service";
}
}
return Status::OK();
}

inline Status CacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id();
if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto sz = rq->buf_data_size();
std::vector<const void *> buffers;
buffers.reserve(sz);
// First piece of data is the cookie and is required
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
auto &cookie = rq->buf_data(0);
// Only if the cookie matches, we can accept insert into this cache that has a build phase
if (!cs->HasBuildPhase() || cookie == cs->cookie()) {
// Push the address of each buffer (in the form of std::string coming in from protobuf) into
// a vector of buffer
for (auto i = 1; i < sz; ++i) {
buffers.push_back(rq->buf_data(i).data());
}
row_id_type id = -1;
RETURN_IF_NOT_OK(cs->CacheRow(buffers, &id));
reply->set_result(std::to_string(id));
} else {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
}
}
return Status::OK();
}

/// This is the main loop the cache server thread(s) are running.
/// Each thread will pop a request and save the result in the same request.
/// The sender will wait on the wait post in the request. Once the request
/// is fulfilled, the server thread will do a post signalling the request is
/// is processed.
Status CacheServer::FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id();
auto shared_pool = comm_layer_->GetSharedMemoryPool();
auto *base = shared_pool->SharedMemoryBaseAddr();
// Ensure we got 3 pieces of data coming in
CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() == 3, "Incomplete data");
// First piece of data is the cookie and is required
auto &cookie = rq->buf_data(0);
// Second piece of data is the address where we can find the serialized data
auto addr = strtoll(rq->buf_data(1).data(), nullptr, 10);
auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
// Third piece of data is the size of the serialized data that we need to transfer
auto sz = strtoll(rq->buf_data(2).data(), nullptr, 10);
// Successful or not, we need to free the memory on exit.
Status rc;
if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
// Only if the cookie matches, we can accept insert into this cache that has a build phase
if (!cs->HasBuildPhase() || cookie == cs->cookie()) {
row_id_type id = -1;
ReadableSlice src(p, sz);
rc = cs->FastCacheRow(src, &id);
reply->set_result(std::to_string(id));
} else {
rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
}
}
// Return the block to the shared memory.
shared_pool->Deallocate(p);
return rc;
}

Status CacheServer::BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id();
if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing row id");
auto &row_id_buf = rq->buf_data(0);
auto p = flatbuffers::GetRoot<TensorRowIds>(row_id_buf.data());
std::vector<row_id_type> row_id;
auto sz = p->row_id()->size();
row_id.reserve(sz);
for (auto i = 0; i < sz; ++i) {
row_id.push_back(p->row_id()->Get(i));
}
int64_t mem_sz = 0;
std::vector<key_size_pair> v;
RETURN_IF_NOT_OK(cs->PreBatchFetch(row_id, &v, &mem_sz));
auto client_flag = rq->flag();
bool local_client = BitTest(client_flag, kLocalClientSupport);
// For large amount data to be sent back, we will use shared memory provided it is a local
// client that has local bypass support
bool local_bypass = local_client ? (mem_sz >= kLocalByPassThreshold) : false;
reply->set_flag(local_bypass ? kDataIsInSharedMemory : 0);
if (local_bypass) {
// We will use shared memory
auto shared_pool = comm_layer_->GetSharedMemoryPool();
auto *base = shared_pool->SharedMemoryBaseAddr();
void *q = nullptr;
RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q));
WritableSlice dest(q, mem_sz);
RETURN_IF_NOT_OK(cs->BatchFetch(row_id, v, &dest));
// We can't return the absolute address which makes no sense to the client.
// Instead we return the difference.
auto difference = reinterpret_cast<int64_t>(q) - reinterpret_cast<int64_t>(base);
reply->set_result(std::to_string(difference));
} else {
// We are going to use std::string to allocate and hold the result which will be eventually
// 'moved' to the protobuf message (which underneath is also a std::string) for the purpose
// to minimize memory copy.
std::string mem;
try {
mem.resize(mem_sz);
CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= mem_sz, "Programming error");
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
}
WritableSlice dest(mem.data(), mem_sz);
RETURN_IF_NOT_OK(cs->BatchFetch(row_id, v, &dest));
reply->set_result(std::move(mem));
}
}
return Status::OK();
}

inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id();
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
CacheService::ServiceStat svc_stat;
RETURN_IF_NOT_OK(cs->GetStat(&svc_stat));
flatbuffers::FlatBufferBuilder fbb;
ServiceStatMsgBuilder bld(fbb);
bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached);
bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached);
bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz);
bld.add_max_row_id(svc_stat.max_);
bld.add_min_row_id(svc_stat.min_);
bld.add_state(svc_stat.state_);
auto offset = bld.Finish();
fbb.Finish(offset);
reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
}
return Status::OK();
}

inline Status CacheSchema(CacheService *cs, CacheRequest *rq) {
auto connection_id = rq->connection_id();
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing schema information");
auto &create_schema_buf = rq->buf_data(0);
RETURN_IF_NOT_OK(cs->CacheSchema(create_schema_buf.data(), create_schema_buf.size()));
}
return Status::OK();
}

inline Status FetchSchema(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id();
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
// We are going to use std::string to allocate and hold the result which will be eventually
// 'moved' to the protobuf message (which underneath is also a std::string) for the purpose
// to minimize memory copy.
std::string mem;
RETURN_IF_NOT_OK(cs->FetchSchema(&mem));
reply->set_result(std::move(mem));
}
return Status::OK();
}

inline Status BuildPhaseDone(CacheService *cs, CacheRequest *rq) {
auto connection_id = rq->connection_id();
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
// First piece of data is the cookie
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
auto &cookie = rq->buf_data(0);
// We can only allow to switch phase is the cookie match.
if (cookie == cs->cookie()) {
RETURN_IF_NOT_OK(cs->BuildPhaseDone());
} else {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
}
}
return Status::OK();
}

Status CacheServer::PurgeCache(CacheService *cs) {
SharedLock lck(&rwLock_);
// If shutdown in progress, ignore the command.
if (global_shutdown_) {
return Status::OK();
}
// it is already purged. Ignore it.
if (cs != nullptr) {
RETURN_IF_NOT_OK(cs->Purge());
}
return Status::OK();
}

inline Status GenerateClientSessionID(session_id_type session_id, CacheReply *reply) {
reply->set_result(std::to_string(session_id));
return Status::OK();
}

/// \brief This is the main loop the cache server thread(s) are running.
/// Each thread will pop a request and send the result back to the client using grpc
/// \return
Status CacheServer::ServerRequest() {
Status CacheServer::ServerRequest(int32_t worker_id) {
TaskManager::FindMe()->Post();
// Loop forever until we are interrupted.
while (true) {
BaseRequest *base_rq = nullptr;
RETURN_IF_NOT_OK(cache_q_->PopFront(&base_rq));
auto cs = GetService(base_rq->connection_id_);
auto &my_que = cache_q_->operator[](worker_id);
// Loop forever until we are interrupted or shutdown.
while (!global_shutdown_) {
CacheServerRequest *cache_req = nullptr;
RETURN_IF_NOT_OK(my_que->PopFront(&cache_req));
auto &rq = cache_req->rq_;
auto &reply = cache_req->reply_;
CacheService *cs = nullptr;
// Request comes in roughly two sets. One set is at the cache level with a connection id.
// The other set is working at a high level and without a connection id
if (!rq.has_connection_info()) {
cs = GetService(rq.connection_id());
}
// Except for creating a new session, we expect cs is not null.
switch (base_rq->type_) {
switch (cache_req->type_) {
case BaseRequest::RequestType::kCacheRow: {
if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
// Look into the flag to see where we can find the data and
// call the appropriate method.
auto flag = rq.flag();
if (BitTest(flag, kDataIsInSharedMemory)) {
cache_req->rc_ = FastCacheRow(cs, &rq, &reply);
} else {
auto *rq = reinterpret_cast<CacheRowRequest *>(base_rq);
// Only if the cookie matches, we can accept insert into this cache that has a build phase
if (!cs->HasBuildPhase() || rq->cookie_ == cs->cookie()) {
rq->rc_ = cs->CacheRow(rq->buffers_, &rq->row_id_from_server_);
} else {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
}
cache_req->rc_ = CacheRow(cs, &rq, &reply);
}
break;
}
case BaseRequest::RequestType::kBatchFetchRows: {
if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto *rq = reinterpret_cast<BatchFetchRequest *>(base_rq);
rq->rc_ = cs->BatchFetch(rq->row_id_, &rq->mem_);
}
cache_req->rc_ = BatchFetchRows(cs, &rq, &reply);
break;
}
case BaseRequest::RequestType::kCreateCache: {
// If the cache is already created we still need to run the creation so that we do sanity checks on the
// client id and return the cache id back to the user.
auto *rq = reinterpret_cast<CreationCacheRequest *>(base_rq);
rq->rc_ = CreateService(rq->connection_id_, rq->cache_mem_sz, rq->flag_, &rq->cookie_);
cache_req->rc_ = CreateService(&rq, &reply);
break;
}
case BaseRequest::RequestType::kPurgeCache: {
if (cs != nullptr) {
base_rq->rc_ = cs->Purge();
} else {
// it is already purged. Ignore it.
base_rq->rc_ = Status::OK();
}
cache_req->rc_ = PurgeCache(cs);
break;
}
case BaseRequest::RequestType::kDestroyCache: {
if (cs != nullptr) {
// We need a strong lock to protect the map.
connection_id_type id = base_rq->connection_id_;
UniqueLock lck(&rwLock_);
// std::map will invoke the constructor of CacheService. So we don't need to do anything here.
auto n = all_caches_.erase(id);
if (n == 0) {
// It has been destroyed by another duplicate request.
MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service";
}
base_rq->rc_ = Status::OK();
} else {
// it is already destroyed. Ignore it.
base_rq->rc_ = Status::OK();
}
cache_req->rc_ = DestroyCache(cs, &rq);
break;
}
case BaseRequest::RequestType::kGetStat: {
if (cs == nullptr) {
std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto *rq = reinterpret_cast<GetStatRequest *>(base_rq);
CacheService::ServiceStat svc_stat;
rq->rc_ = cs->GetStat(&svc_stat);
if (rq->rc_.IsOk()) {
flatbuffers::FlatBufferBuilder fbb;
ServiceStatMsgBuilder bld(fbb);
bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached);
bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached);
bld.add_max_row_id(svc_stat.max_);
bld.add_min_row_id(svc_stat.min_);
bld.add_state(svc_stat.state_);
auto offset = bld.Finish();
fbb.Finish(offset);
rq->rc_ = rq->mem_.allocate(fbb.GetSize());
if (rq->rc_.IsOk()) {
WritableSlice dest(rq->mem_.GetMutablePointer(), fbb.GetSize());
ReadableSlice src(fbb.GetBufferPointer(), fbb.GetSize());
RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src));
}
}
}
cache_req->rc_ = GetStat(cs, &rq, &reply);
break;
}
case BaseRequest::RequestType::kCacheSchema: {
if (cs == nullptr) {
std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto *rq = reinterpret_cast<CacheSchemaRequest *>(base_rq);
rq->rc_ = cs->CacheSchema(rq->buf_, rq->len_of_buf_);
}
cache_req->rc_ = CacheSchema(cs, &rq);
break;
}
case BaseRequest::RequestType::kFetchSchema: {
if (cs == nullptr) {
std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto *rq = reinterpret_cast<FetchSchemaRequest *>(base_rq);
rq->rc_ = cs->FetchSchema(&rq->mem_);
}
cache_req->rc_ = FetchSchema(cs, &rq, &reply);
break;
}
case BaseRequest::RequestType::kBuildPhaseDone: {
if (cs == nullptr) {
std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto *rq = reinterpret_cast<BuildPhaseDoneRequest *>(base_rq);
// We can only allow to switch phase is the cookie match.
if (rq->cookie_ == cs->cookie()) {
rq->rc_ = cs->BuildPhaseDone();
} else {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
}
}
cache_req->rc_ = BuildPhaseDone(cs, &rq);
break;
}
case BaseRequest::RequestType::kDropSession: {
cache_req->rc_ = DestroySession(&rq);
break;
}
case BaseRequest::RequestType::kGenerateSessionId: {
cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply);
break;
}
case BaseRequest::RequestType::kAllocateSharedBlock: {
cache_req->rc_ = AllocateSharedMemory(&rq, &reply);
break;
}
case BaseRequest::RequestType::kFreeSharedBlock: {
cache_req->rc_ = FreeSharedMemory(&rq);
break;
}
case BaseRequest::RequestType::kStopService: {
// This command shutdowns everything.
cache_req->rc_ = GlobalShutdown();
break;
}
default:
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Unknown request type");
std::string errMsg("Unknown request type : ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
}
// Notify it is done, and move on to the next request.
base_rq->wp_.Set();
Status2CacheReply(cache_req->rc_, &reply);
cache_req->st_ = CacheServerRequest::STATE::FINISH;
// We will re-tag the request back to the grpc queue. Once it comes back from the client,
// the CacheServerRequest, i.e. the pointer cache_req, will be free
if (!global_shutdown_) {
cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req);
}
}
return Status::OK();
}

connection_id_type CacheServer::GetConnectionID(session_id_type session_id, uint32_t crc) const {
connection_id_type connection_id =
(static_cast<connection_id_type>(session_id) << 32u) | static_cast<connection_id_type>(crc);
return connection_id;
}

session_id_type CacheServer::GetSessionID(connection_id_type connection_id) const {
return static_cast<session_id_type>(connection_id >> 32u);
}

CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port,
int32_t shared_meory_sz_in_gb)
: top_(spill_path),
num_workers_(num_workers),
port_(port),
shared_memory_sz_in_gb_(shared_meory_sz_in_gb),
global_shutdown_(false) {}

Status CacheServer::Run() {
RETURN_IF_NOT_OK(ServiceStart());
// This is called by the main function and we shouldn't exit. Otherwise the main thread
// will just shutdown. So we will call some function that never return unless error.
// One good case will be simply to wait for all threads to return.
RETURN_IF_NOT_OK(vg_.join_all(Task::WaitFlag::kBlocking));
return Status::OK();
}

Status CacheServer::GetFreeRequestTag(int32_t queue_id, CacheServerRequest **q) {
RETURN_UNEXPECTED_IF_NULL(q);
CacheServer &cs = CacheServer::GetInstance();
CacheServerRequest *p;
RETURN_IF_NOT_OK(cs.free_list_->operator[](queue_id)->PopFront(&p));
*q = p;
return Status::OK();
}

Status CacheServer::ReturnRequestTag(CacheServerRequest *p) {
RETURN_UNEXPECTED_IF_NULL(p);
int32_t myQID = p->getQid();
// Free any memory from the protobufs
p->~CacheServerRequest();
// Re-initialize the memory
new (p) CacheServerRequest(myQID);
// Now we return it back to free list.
CacheServer &cs = CacheServer::GetInstance();
RETURN_IF_NOT_OK(cs.free_list_->operator[](myQID)->Add(p));
return Status::OK();
}

Status CacheServer::DestroySession(CacheRequest *rq) {
CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id");
auto drop_session_id = rq->connection_info().session_id();
UniqueLock lck(&rwLock_);
for (auto &cs : all_caches_) {
auto connection_id = cs.first;
auto session_id = GetSessionID(connection_id);
// We can just call DestroyCache() but we are holding a lock already. Doing so will cause deadlock.
// So we will just manually do it.
if (session_id == drop_session_id) {
// std::map will invoke the destructor of CacheService. So we don't need to do anything here.
auto n = all_caches_.erase(connection_id);
MS_LOG(INFO) << "Destroy " << n << " copies of cache with id " << connection_id;
}
}
return Status::OK();
}

session_id_type CacheServer::GenerateSessionID() const {
SharedLock lock(&rwLock_);
auto mt = GetRandomDevice();
std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<session_id_type>::max());
session_id_type session_id;
bool duplicate = false;
do {
session_id = distribution(mt);
auto it = history_sessions_.find(session_id);
duplicate = (it != history_sessions_.end());
} while (duplicate);
return session_id;
}

Status CacheServer::AllocateSharedMemory(CacheRequest *rq, CacheReply *reply) {
auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, 10);
auto shared_pool = comm_layer_->GetSharedMemoryPool();
auto *base = shared_pool->SharedMemoryBaseAddr();
void *p = nullptr;
RETURN_IF_NOT_OK(shared_pool->Allocate(requestedSz, &p));
// We can't return the absolute address which makes no sense to the client.
// Instead we return the difference.
auto difference = reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base);
reply->set_result(std::to_string(difference));
return Status::OK();
}

Status CacheServer::FreeSharedMemory(CacheRequest *rq) {
auto shared_pool = comm_layer_->GetSharedMemoryPool();
auto *base = shared_pool->SharedMemoryBaseAddr();
auto addr = strtoll(rq->buf_data(0).data(), nullptr, 10);
auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
shared_pool->Deallocate(p);
return Status::OK();
}

Status CacheServer::RpcRequest(int32_t worker_id) {
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(comm_layer_->HandleRequest(worker_id));
return Status::OK();
}

Status CacheServer::GlobalShutdown() {
// Let's shutdown in proper order.
bool expected = false;
if (global_shutdown_.compare_exchange_strong(expected, true)) {
MS_LOG(WARNING) << "Shutting down server.";
// Shutdown the grpc queue. No longer accept any new comer.
// The threads we spawn to work on the grpc queue will exit themselves once
// they notice the queue has been shutdown.
comm_layer_->Shutdown();
// Now we interrupt any threads that are waiting on cache_q_
vg_.interrupt_all();
// The next thing to do drop all the caches.
UniqueLock lck(&rwLock_);
for (auto &it : all_caches_) {
auto id = it.first;
MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id);
// Wait for all outstanding work to be finished.
auto &cs = it.second;
UniqueLock cs_lock(&cs->rw_lock_);
// std::map will invoke the destructor of CacheService. So we don't need to do anything here.
(void)all_caches_.erase(id);
}
}
return Status::OK();
}

Status CacheServer::Builder::SanityCheck() {
if (shared_memory_sz_in_gb_ <= 0) {
RETURN_STATUS_UNEXPECTED("Shared memory size (in GB unit) must be positive");
}
if (num_workers_ <= 0) {
RETURN_STATUS_UNEXPECTED("Number of parallel workers must be positive");
}
if (!top_.empty()) {
auto p = top_.data();
if (p[0] != '/') {
RETURN_STATUS_UNEXPECTED("Spilling directory must be an absolute path");
}
// Check if the spill directory is writable
Path spill(top_);
auto t = spill / Services::GetUniqueID();
Status rc = t.CreateDirectory();
if (rc.IsOk()) {
rc = t.Remove();
}
if (rc.IsError()) {
RETURN_STATUS_UNEXPECTED("Spilling directory is not writable\n" + rc.ToString());
}
}
return Status::OK();
}
CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers)
: top_(spill_path), num_workers_(num_workers) {}
} // namespace dataset
} // namespace mindspore

+ 154
- 14
mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h View File

@@ -24,8 +24,11 @@
#include <utility>
#include <vector>
#include <map>
#include <set>
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_grpc_server.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/util/cache_pool.h"
#include "minddata/dataset/util/lock.h"
@@ -37,43 +40,131 @@

namespace mindspore {
namespace dataset {
class BaseRequest;
/// \brief A server which provides CacheService services.
class CacheServer : public Service {
public:
friend class Services;
using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>;
class Builder {
public:
Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4) {}

/// \brief Getter functions
const std::string &getTop() const { return top_; }
int32_t getNumWorkers() const { return num_workers_; }
int32_t getPort() const { return port_; }
int32_t getSharedMemorySzInGb() const { return shared_memory_sz_in_gb_; }

Builder &SetRootDirectory(std::string root) {
top_ = std::move(root);
return *this;
}
Builder &SetNumWorkers(int32_t n) {
num_workers_ = n;
return *this;
}
Builder &SetPort(int32_t p) {
port_ = p;
return *this;
}
Builder &SetSharedMemorySizeInGB(int32_t sz) {
shared_memory_sz_in_gb_ = sz;
return *this;
}

Status SanityCheck();

void Print(std::ostream &out) const {
out << "Summary of the cache server configuration\n"
<< "Spill directory: " << getTop() << "\n"
<< "Number of parallel workers: " << getNumWorkers() << "\n"
<< "Tcp/ip port: " << getPort() << "\n"
<< "Shared memory size (in GB): " << getSharedMemorySzInGb();
}

friend std::ostream &operator<<(std::ostream &out, const Builder &bld) {
bld.Print(out);
return out;
}

Status Build() {
RETURN_IF_NOT_OK(SanityCheck());
// We need to bring up the Task Manager by bringing up the Services singleton.
RETURN_IF_NOT_OK(Services::CreateInstance());
RETURN_IF_NOT_OK(CacheServer::CreateInstance(top_, num_workers_, port_, shared_memory_sz_in_gb_));
return Status::OK();
}

private:
std::string top_;
int32_t num_workers_;
int32_t port_;
int32_t shared_memory_sz_in_gb_;
};
CacheServer(const CacheServer &) = delete;
CacheServer &operator=(const CacheServer &) = delete;
CacheServer(CacheServer &&) = delete;
CacheServer &operator=(CacheServer &) = delete;
static CacheServer &GetInstance() noexcept { return Services::getCacheServer(); }
Status DoServiceStart() override;
Status DoServiceStop() override;
~CacheServer() { (void)ServiceStop(); }

static Status CreateInstance(const std::string &spill_path, int32_t num_workers, int32_t port,
int32_t shared_memory_sz) {
std::call_once(init_instance_flag_, [&]() -> Status {
auto &svcManager = Services::GetInstance();
RETURN_IF_NOT_OK(svcManager.AddHook(&instance_, spill_path, num_workers, port, shared_memory_sz));
return Status::OK();
});
return Status::OK();
}

static CacheServer &GetInstance() { return *instance_; }

/// \brief For the current demonstration, a cache client contacts cache server using a Queue.
/// \param rq
/// \return Status object
Status PushRequest(BaseRequest *rq) {
Status PushRequest(int32_t queue_id, CacheServerRequest *rq) {
RETURN_UNEXPECTED_IF_NULL(rq);
RETURN_IF_NOT_OK(cache_q_->Add(rq));
RETURN_IF_NOT_OK(cache_q_->operator[](queue_id)->Add(rq));
return Status::OK();
}

/// \\brief Kick off server threads. Never return unless error out.
Status Run();

/// \brief Get a free tag
/// \param q[in] pointer to a pointer to a CacheServerRequest
/// \return Status object
static Status GetFreeRequestTag(int32_t queue_id, CacheServerRequest **q);

/// \brief Return a tag to the free list
/// \param p[in] pointer to already finished CacheServerRequest tag
/// \return Status object
static Status ReturnRequestTag(CacheServerRequest *p);

private:
static std::once_flag init_instance_flag_;
static CacheServer *instance_;
mutable RWLock rwLock_;
std::string top_;
cache_index all_caches_;
std::shared_ptr<Queue<BaseRequest *>> cache_q_;
std::set<session_id_type> history_sessions_;
std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_;
std::shared_ptr<QueueList<CacheServerRequest *>> free_list_;
std::vector<std::unique_ptr<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>> tag_;
std::shared_ptr<CacheServerGreeterImpl> comm_layer_;
std::shared_ptr<MemoryPool> mp_;
TaskGroup vg_;
int32_t num_workers_;
int32_t port_;
int32_t shared_memory_sz_in_gb_;
std::atomic<bool> global_shutdown_;

/// \brief Constructor
/// \param spill_path Top directory for spilling buffers to.
/// \param num_workers Number of threads for handling requests.
explicit CacheServer(const std::string &spill_path, int32_t num_workers = 3);
explicit CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, int32_t share_memory_sz_in_gb);

/// \brief Locate a cache service from connection id.
/// \return Pointer to cache service. Null if not found
@@ -82,16 +173,65 @@ class CacheServer : public Service {
/// \brief Create a cache service. We allow multiple clients to create the same cache service.
/// Subsequent duplicate requests are ignored. The first cache client to create the service will be given
/// a special unique cookie.
/// \param[in] connection_id This is from a Cache client.
/// \param[in] cache_mem_sz
/// \param[in] flag
/// \param[out] out_cookie Only the first cache client will be given a special cookie to identify the creator
/// \return Status object
Status CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, BaseRequest::CreateCacheFlag flag,
std::string *out_cookie);
Status CreateService(CacheRequest *rq, CacheReply *reply);

/// \brief Destroy a cache service
/// \param cs
/// \param rq
/// \return
Status DestroyCache(CacheService *cs, CacheRequest *rq);
Status PurgeCache(CacheService *cs);

/// \brief Entry point for all internal server threads.
Status ServerRequest(int32_t worker_id);

/// \brief Entry point for all grpc threads.
/// \return
Status RpcRequest(int32_t worker_id);

Status DestroySession(CacheRequest *rq);

/// \brief Create a connection id from a session id and a crc
/// \param session_id
/// \param crc
/// \return connection id
connection_id_type GetConnectionID(session_id_type session_id, uint32_t crc) const;

/// \brief Extract the session id from a connection id
/// \param connection_id
/// \return session id
session_id_type GetSessionID(connection_id_type connection_id) const;

/// \brief Generate a session ID for the client
/// \return Session ID
session_id_type GenerateSessionID() const;

/// \brief Handle kAllocateSharedBlock request
/// \param rq CacheRequest
/// \param reply CacheReply
/// \return Status object
Status AllocateSharedMemory(CacheRequest *rq, CacheReply *reply);

/// \brief Handle kFreeSharedBlock request
/// \param rq
/// \return Status object
Status FreeSharedMemory(CacheRequest *rq);

/// \brief Entry point for all server threads.
Status ServerRequest();
/// \brief Handle kFastCacheRow request
/// \return Status object
Status FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply);

/// \brief Internal function to do row batch fetch
/// \param cs CacheService
/// \param rq Request
/// \param reply Reply
/// \return
Status BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply);

/// \brief A proper shutdown of the server
/// \return Status object
Status GlobalShutdown();
};
} // namespace dataset
} // namespace mindspore


+ 79
- 30
mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc View File

@@ -76,7 +76,7 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
*row_id_generated = GetNextRowId();
// Some debug information on how many rows we have generated so far.
if ((*row_id_generated) % 1000 == 0) {
MS_LOG(DEBUG) << "Number of rows cached: " << *row_id_generated;
MS_LOG(DEBUG) << "Number of rows cached: " << (*row_id_generated) + 1;
}
} else {
if (msg->row_id() < 0) {
@@ -114,6 +114,45 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
RETURN_STATUS_UNEXPECTED(e.what());
}
}

Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) {
SharedLock rw(&rw_lock_);
RETURN_UNEXPECTED_IF_NULL(row_id_generated);
if (st_ == State::kFetchPhase) {
// For this kind of cache service, once we are done with the build phase into fetch phase, we can't
// allow other to cache more rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
try {
// If we don't need to generate id, we need to find it from the buffer.
if (generate_id_) {
*row_id_generated = GetNextRowId();
// Some debug information on how many rows we have generated so far.
if ((*row_id_generated) % 1000 == 0) {
MS_LOG(DEBUG) << "Number of rows cached: " << (*row_id_generated) + 1;
}
} else {
auto msg = GetTensorRowHeaderMsg(src.GetPointer());
if (msg->row_id() < 0) {
std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id());
RETURN_STATUS_UNEXPECTED(errMsg);
}
*row_id_generated = msg->row_id();
}
// Now we cache the flat buffer.
CachePool::key_type key;
RETURN_IF_NOT_OK(cp_->Insert({src}, &key));
Status rc = map_->DoInsert(*row_id_generated, key);
if (rc == Status(StatusCode::kDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key.";
} else {
RETURN_IF_NOT_OK(rc);
}
return Status::OK();
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
}
std::ostream &operator<<(std::ostream &out, const CacheService &cs) {
// Then show any custom derived-internal stuff
out << "\nCache memory size: " << cs.cache_mem_sz_;
@@ -155,20 +194,15 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) {
}
return Status::OK();
}
Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const {
RETURN_UNEXPECTED_IF_NULL(out);

Status CacheService::PreBatchFetch(const std::vector<row_id_type> &v, std::vector<key_size_pair> *out,
int64_t *mem_sz) {
SharedLock rw(&rw_lock_);
if (st_ == State::kBuildPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
RETURN_UNEXPECTED_IF_NULL(out);
RETURN_UNEXPECTED_IF_NULL(mem_sz);
const auto num_elements = v.size();
int64_t mem_sz = (num_elements + 1) * sizeof(int64_t);
int64_t data_offset = mem_sz;
std::vector<int64_t> sz_v;
std::vector<CachePool::key_type> keys;
sz_v.reserve(num_elements);
keys.reserve(num_elements);
*mem_sz = (num_elements + 1) * sizeof(int64_t);
(*out).reserve(num_elements);
for (auto row_id : v) {
auto r = map_->Search(row_id);
if (r.second) {
@@ -180,25 +214,33 @@ Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint
errMsg += std::to_string(key);
RETURN_STATUS_UNEXPECTED(errMsg);
}
keys.push_back(key);
sz_v.push_back(sz);
mem_sz += sz;
(*out).emplace_back(key, sz);
(*mem_sz) += sz;
} else {
keys.push_back(-1);
sz_v.push_back(0);
(*out).emplace_back(-1, 0);
}
}
MemGuard<uint8_t> mem;
RETURN_IF_NOT_OK(mem.allocate(mem_sz));
auto *offset_array = reinterpret_cast<int64_t *>(mem.GetMutablePointer());
return Status::OK();
}

Status CacheService::BatchFetch(const std::vector<row_id_type> &v, const std::vector<key_size_pair> &info,
WritableSlice *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
SharedLock rw(&rw_lock_);
if (st_ == State::kBuildPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
const auto num_elements = v.size();
int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
offset_array[0] = data_offset;
WritableSlice all(mem.GetMutablePointer(), mem.GetSizeInBytes());
for (auto i = 0; i < num_elements; ++i) {
auto sz = sz_v.at(i);
auto sz = info.at(i).second;
offset_array[i + 1] = offset_array[i] + sz;
if (sz > 0) {
WritableSlice row_data(all, offset_array[i], sz);
auto key = keys.at(i);
WritableSlice row_data(*out, offset_array[i], sz);
auto key = info.at(i).first;
size_t bytesRead = 0;
RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead));
if (bytesRead != sz) {
@@ -208,7 +250,6 @@ Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint
}
}
}
*out = std::move(mem);
return Status::OK();
}
Status CacheService::CacheSchema(const void *buf, int64_t len) {
@@ -232,18 +273,26 @@ Status CacheService::CacheSchema(const void *buf, int64_t len) {
}
return Status::OK();
}
Status CacheService::FetchSchema(MemGuard<uint8_t> *out) const {
Status CacheService::FetchSchema(std::string *out) const {
SharedLock rw(&rw_lock_);
if (st_ == State::kBuildPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
RETURN_UNEXPECTED_IF_NULL(out);
MemGuard<uint8_t> mem;
// We are going to use std::string to allocate and hold the result which will be eventually
// 'moved' to the protobuf message (which underneath is also a std::string) for the purpose
// to minimize memory copy.
std::string mem;
if (schema_key_ >= 0) {
auto len = cp_->GetSize(schema_key_);
RETURN_IF_NOT_OK(mem.allocate(len));
auto slice = WritableSlice(mem.GetMutablePointer(), len);
try {
mem.resize(len);
CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= len, "Programming error");
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
}
auto slice = WritableSlice(mem.data(), len);
RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice));
*out = std::move(mem);
} else {


+ 18
- 4
mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h View File

@@ -28,7 +28,6 @@
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/util/btree.h"
#include "minddata/dataset/util/cache_pool.h"
@@ -38,7 +37,8 @@

namespace mindspore {
namespace dataset {
struct CacheStat;
/// Some typedef used for BatchFetch
using key_size_pair = std::pair<CachePool::key_type, size_t>;
/// \brief A cache service for storing/fetching buffers to in memory cache and may spill to disk the cache service is
/// created to support spilling
class CacheService : public Service {
@@ -69,12 +69,26 @@ class CacheService : public Service {
/// \param[out] row_id_generated The row id assigned to this row if any
/// \return Status object
Status CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated);

/// \brief A fast version of CacheRow where all the data is already in one contiguous piece.
/// \param src Slice of the data
/// \param row_id_generated
/// \return Status object
Status FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated);

/// \brief This function is used in preparation for batch fetching.
/// It calculates how much memory we should allocate and which row id are present.
/// \param[in/out] Pointer to vector of <CachePool::key_type, size_t>
/// \param[in/out] mem_sz how much memory is required to batch fetch
/// \return Status object
Status PreBatchFetch(const std::vector<row_id_type> &v, std::vector<key_size_pair> *, int64_t *mem_sz);

/// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
/// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
/// \param[in] v A vector of row id.
/// \param[out] out A contiguous memory buffer that holds the requested rows.
/// \return Status object
Status BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const;
Status BatchFetch(const std::vector<row_id_type> &v, const std::vector<key_size_pair> &, WritableSlice *out) const;

/// \brief Getter function
/// \return Spilling path
@@ -102,7 +116,7 @@ class CacheService : public Service {
/// \brief Fetch schema
/// \param out A contiguous memory that contains the serialized form of schema.
/// \return Status object
Status FetchSchema(MemGuard<uint8_t> *out) const;
Status FetchSchema(std::string *out) const;
/// \brief Purge the content of a cache
/// \return Status object
Status Purge();


+ 14
- 1
mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs View File

@@ -60,10 +60,11 @@ table TensorRowIds {
}

/// Statistics returned from each cache service
/// \note It must match CacheService::ServiceStat
/// \note It must match CacheServiceStat
table ServiceStatMsg {
num_mem_cached:int64;
num_disk_cached:int64;
avg_cache_sz:int64;
min_row_id:int64;
max_row_id:int64;
state:int8;
@@ -79,3 +80,15 @@ table ColumnNameMsg {
table SchemaMsg {
column:[ColumnNameMsg];
}

/// Part of the CreateCacheRequest
table CreateCacheRequestMsg {
cache_mem_sz:int64;
flag:uint32;
}

/// Return result of CreateCacheRequest
table CreateCacheReplyMsg {
connection_id:int64;
cookie:string;
}

+ 45
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/stub/cache_grpc_client.h View File

@@ -0,0 +1,45 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_

#include <memory>
#include <string>
#include "proto/cache_grpc.pb.h"
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/util/service.h"

namespace mindspore {
namespace dataset {
class CacheClientGreeter : public Service {
public:
explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers) {}
~CacheClientGreeter() override {}
Status DoServiceStart() override { RETURN_STATUS_UNEXPECTED("Not supported"); }
Status DoServiceStop() override { RETURN_STATUS_UNEXPECTED("Not supported"); }

void *SharedMemoryBaseAddr() { return nullptr; }
Status HandleRequest(std::shared_ptr<BaseRequest> rq) { RETURN_STATUS_UNEXPECTED("Not supported"); }
Status AttachToSharedMemory(int32_t port, bool *local_bypass) { RETURN_STATUS_UNEXPECTED("Not supported"); }

protected:
private:
};
} // namespace dataset
} // namespace mindspore

#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_

+ 119
- 20
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc View File

@@ -16,6 +16,7 @@
#include "minddata/dataset/engine/datasetops/cache_base_op.h"
#include <iomanip>
#include <iostream>
#include <utility>
#include "minddata/dataset/engine/execution_tree.h"

namespace mindspore {
@@ -47,22 +48,39 @@ Status CacheBase::Reset() {
}
CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_workers, op_connector_size, sampler),
cache_client_(cache_client),
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
row_cnt_(0),
num_cache_miss_(0),
cache_client_(std::move(cache_client)),
rows_per_buffer_(rows_per_buf),
// We can cause deadlock if this internal Connector size is too small.
keys_miss_(num_workers_, 1, connector_capacity_) {
keys_miss_(num_workers_, 1, connector_capacity_),
prefetch_size_(cache_client_->getPrefetchSize()) {
io_block_queues_.Init(num_workers, op_connector_size);
prefetch_queues_.Init(num_workers, op_connector_size);
sampler_queue_ = std::make_unique<Queue<std::shared_ptr<Tensor>>>(op_connector_size);
}
// Common function to fetch samples from the sampler and send them using the io_block_queues to
// the parallel workers
Status CacheBase::FetchSamplesToWorkers() {
int64_t buf_cnt = 0;
int64_t wait_cnt = 0;
// Kick off several threads which will prefetch prefetch_size_ rows in advance. The rows_per_buffers_
// is too small (1 by default) and won't help performance.
RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Dispatcher", std::bind(&CacheBase::Dispatcher, this)));
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1)));
// Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them
// to the WorkerEntry.
do {
epoch_sync_.Clear();
if (AllowCacheMiss() && wait_cnt > 0) {
MS_LOG(WARNING) << "Epoch: " << wait_cnt << " Cache Miss : " << num_cache_miss_
<< " Total number of rows : " << row_cnt_;
}
num_cache_miss_ = 0;
row_cnt_ = 0;
++wait_cnt;
std::vector<row_id_type> keys;
int64_t row_cnt = 0;
keys.reserve(rows_per_buffer_);
std::unique_ptr<DataBuffer> sampler_buffer;
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
@@ -70,10 +88,13 @@ Status CacheBase::FetchSamplesToWorkers() {
TensorRow sample_row;
RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row));
std::shared_ptr<Tensor> sample_ids = sample_row[0];
// Send the sampler tensor to other thread for prefetching. We are using shared pointer so it
// won't go out scope until it is really not in use.
RETURN_IF_NOT_OK(sampler_queue_->Add(sample_ids));
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
keys.push_back(*itr);
++row_cnt;
if (row_cnt % rows_per_buffer_ == 0) {
++row_cnt_;
if (row_cnt_ % rows_per_buffer_ == 0) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
keys.clear();
@@ -90,7 +111,7 @@ Status CacheBase::FetchSamplesToWorkers() {
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
// If repeat but the not last repeat, wait for reset.
if (!IsLastIteration()) {
MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt;
MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt;
RETURN_IF_NOT_OK(epoch_sync_.Wait());
} else {
// We can break out from the loop.
@@ -101,13 +122,21 @@ Status CacheBase::FetchSamplesToWorkers() {
// Flow the eof before exit
RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
// Ask all the workers to quit.
// Shutdown threads
std::shared_ptr<Tensor> empty;
RETURN_IF_NOT_OK(sampler_queue_->Add(std::move(empty)));
for (int32_t i = 0; i < num_workers_; i++) {
RETURN_IF_NOT_OK(
io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
}
// Dump the last epoch result (approximately) without waiting for the worker threads to come back.
if (AllowCacheMiss()) {
MS_LOG(WARNING) << "Epoch: " << wait_cnt << " Cache Miss : " << num_cache_miss_
<< " Total number of rows : " << row_cnt_;
}
return Status::OK();
}

Status CacheBase::FetchFromCache(int32_t worker_id) {
int64_t buffer_id = worker_id;
std::unique_ptr<IOBlock> blk;
@@ -133,23 +162,16 @@ Status CacheBase::FetchFromCache(int32_t worker_id) {
}
std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone);
std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>();
TensorTable ttbl;
RETURN_IF_NOT_OK(cache_client_->GetRows(keys, &ttbl));
auto row_it = ttbl.begin();
std::vector<row_id_type> cache_miss;
cache_miss.reserve(keys.size());
for (auto row_id : keys) {
auto &row = *row_it;
TensorRow row;
// Block until the row shows up in the pool.
RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, &row));
if (row.empty()) {
if (AllowCacheMiss()) {
cache_miss.push_back(row_id);
} else {
std::string errMsg = "Row id " + std::to_string(row_id) + " not found.";
RETURN_STATUS_UNEXPECTED(errMsg);
}
cache_miss.push_back(row_id);
}
que->push_back(std::move(row));
++row_it;
}
db->set_tensor_table(std::move(que));
if (AllowCacheMiss()) {
@@ -162,12 +184,17 @@ Status CacheBase::FetchFromCache(int32_t worker_id) {
} while (true);
return Status::OK();
}

Status CacheBase::RegisterResources() {
RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(sampler_queue_->Register(tree_->AllTasks()));
return Status::OK();
}
CacheBase::~CacheBase() {}

CacheBase::~CacheBase() = default;

Status CacheBase::UpdateColumnMapFromCache() {
Status rc;
// Get the schema from the server. It may not be there yet. So tolerate the error.
@@ -180,5 +207,77 @@ Status CacheBase::UpdateColumnMapFromCache() {
}
return rc;
}

Status CacheBase::Dispatcher() {
TaskManager::FindMe()->Post();
int64_t buf_cnt = 0;
int64_t num_row = 0;
std::vector<row_id_type> keys;
keys.reserve(prefetch_size_);
do {
keys.clear();
std::shared_ptr<Tensor> sample_ids;
RETURN_IF_NOT_OK(sampler_queue_->PopFront(&sample_ids));
if (sample_ids == nullptr) {
// A null shared pointer signal times to quit.
// Also signal all prefetchers to quit.
for (int32_t i = 0; i < num_workers_; i++) {
RETURN_IF_NOT_OK(
prefetch_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
}
break;
}
// Now we distribute the sampler ids to each prefetcher according to the prefetch size.
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
keys.push_back(*itr);
++num_row;
if (num_row % prefetch_size_ == 0) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
keys.clear();
}
}
// Send the remaining sample id
if (!keys.empty()) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
}
} while (true);
return Status::OK();
}

Status CacheBase::Prefetcher(int32_t worker_id) {
TaskManager::FindMe()->Post();
std::vector<row_id_type> prefetch_keys;
prefetch_keys.reserve(prefetch_size_);
do {
prefetch_keys.clear();
std::unique_ptr<IOBlock> blk;
RETURN_IF_NOT_OK(prefetch_queues_[worker_id]->PopFront(&blk));
RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys));
if (prefetch_keys.empty()) {
// Empty keys mean time to quit.
break;
}
TensorTable ttbl;
RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl));
auto row_it = ttbl.begin();
for (auto row_id : prefetch_keys) {
auto &row = *row_it;
if (row.empty()) {
if (AllowCacheMiss()) {
++num_cache_miss_;
} else {
std::string errMsg = "Row id " + std::to_string(row_id) + " not found.";
RETURN_STATUS_UNEXPECTED(errMsg);
}
}
// Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row
RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
++row_it;
}
} while (true);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 15
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h View File

@@ -16,6 +16,8 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_

#include <atomic>
#include <deque>
#include <memory>
#include <string>
#include <utility>
@@ -28,8 +30,9 @@
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/queue_map.h"
#include "minddata/dataset/util/semaphore.h"
#include "minddata/dataset/util/wait_post.h"
#include "minddata/dataset/engine/datasetops/cache_base_op.h"
namespace mindspore {
namespace dataset {
/// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities.
@@ -82,10 +85,13 @@ class CacheBase : public ParallelOp {

protected:
constexpr static int32_t eoe_row_id = -1;
int64_t row_cnt_;
std::atomic<int64_t> num_cache_miss_;
std::shared_ptr<CacheClient> cache_client_;
WaitPost epoch_sync_;
int32_t rows_per_buffer_;
Connector<std::vector<row_id_type>> keys_miss_;
QueueMap<row_id_type, TensorRow> prefetch_;

/// \brief Common function to register resources for interrupt
/// \note Derived should override this function for extra resources to be registered
@@ -103,7 +109,15 @@ class CacheBase : public ParallelOp {

private:
constexpr static int32_t connector_capacity_ = 1024;
int32_t prefetch_size_;
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
QueueList<std::unique_ptr<IOBlock>> prefetch_queues_;
std::unique_ptr<Queue<std::shared_ptr<Tensor>>> sampler_queue_;

Status Dispatcher();
/// \brief Prefetcher. It prefetch the rows from cache server
/// \return Status object.
Status Prefetcher(int32_t worker_id);
};
} // namespace dataset
} // namespace mindspore


+ 95
- 80
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc View File

@@ -16,8 +16,10 @@
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"

#include <algorithm>
#include <chrono>
#include <functional>
#include <iomanip>
#include <utility>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/global_context.h"
@@ -41,9 +43,13 @@ void CacheMergeOp::Print(std::ostream &out, bool show_all) const {
out << "\n\n";
}
}

CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler)
: ParallelOp(numWorkers, opConnectorSize, sampler), num_cleaners_(numCleaners), cache_client_(cache_client) {}
: ParallelOp(numWorkers, opConnectorSize, sampler),
num_cleaners_(numCleaners),
cache_client_(std::move(cache_client)) {}

Status CacheMergeOp::operator()() {
// A queue of row id to let cleaner send cache miss rows to the cache server
// We don't want a small queue as this will block the parallel op workers.
@@ -62,6 +68,7 @@ Status CacheMergeOp::operator()() {
TaskManager::FindMe()->Post();
return Status::OK();
}

// Each parallel worker will pop from the CacheHit stream. If there is a missing TensorRow, we will wait
// until it shows up in the pool.
Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
@@ -82,10 +89,8 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
if (row.empty()) {
auto row_id = row.getId();
TensorRowRequest *rq = nullptr;
RETURN_IF_NOT_OK(GetRq(row_id, &rq));
// Block until the row shows up in the pool.
RETURN_IF_NOT_OK(rq->Wait(&row));
RETURN_IF_NOT_OK(cache_miss_.PopFront(row_id, &row));
}
tbl->push_back(std::move(row));
}
@@ -97,6 +102,7 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
RETURN_IF_NOT_OK(EofReceived(worker_id));
return Status::OK();
}

Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
TaskManager::FindMe()->Post();
// We will simply pop TensorRow from the stream and insert them into the pool and
@@ -123,17 +129,27 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
std::string errMsg = "Expect positive row id: " + std::to_string(row_id);
RETURN_STATUS_UNEXPECTED(errMsg);
}
TensorRowRequest *rq = nullptr;
// Technically number of this row shows up in the cache miss stream is equal to the number
// of P() call. However the cleaner wants it too. So we need an extra copy.
TensorRowCacheRequest *rq;
RETURN_IF_NOT_OK(GetRq(row_id, &rq));
rq->WakeUpAny(std::move(row));
// Let the cleaner to flush out this row (async) to the cache server.
RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) {
// We will send the request async. But any error we most
// likely ignore and continue.
Status rc;
rc = rq->AsyncSendCacheRequest(cache_client_, row);
if (rc.IsOk()) {
RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
}
}
RETURN_IF_NOT_OK(cache_miss_.Add(row_id, std::move(row)));
}
}
RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId));
}
return Status::OK();
}

Status CacheMergeOp::Cleaner() {
TaskManager::FindMe()->Post();
while (true) {
@@ -142,45 +158,28 @@ Status CacheMergeOp::Cleaner() {
if (row_id < 0) {
break;
}
TensorRowRequest *rq = nullptr;
// Locate the cache request
TensorRowCacheRequest *rq;
RETURN_IF_NOT_OK(GetRq(row_id, &rq));
if (rq->GetState() == TensorRowRequest::State::kClean) {
// If already flushed, move on to the next one.
// If already flushed, move on to the next one.
if (rq->GetState() == TensorRowCacheRequest::State::kClean) {
continue;
}
TensorRow row;
RETURN_IF_NOT_OK(rq->Release(&row));
CHECK_FAIL_RETURN_UNEXPECTED(!row.empty(), "Programming error.");
Status rc = cache_client_->WriteRow(row);
// Bad rc should not bring down the pipeline
Status rc = rq->CheckCacheResult();
if (rc.IsError()) {
MS_LOG(WARNING) << "Cache not successful." << rc.ToString();
// If interrupt, time to quit.
if (rc.get_code() == StatusCode::kInterrupted) {
return Status::OK();
}
MS_LOG(INFO) << "Cache row not successful: " << rc.ToString();
// Bad rc should not bring down the pipeline. We will simply continue and
// change the state back to empty. We don't need a CAS from CLEAN back to EMPTY.
rq->SetState(TensorRowCacheRequest::State::kEmpty);
}
rq->SetState(TensorRowRequest::State::kClean);
}
return Status::OK();
}

Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowRequest **out) {
RETURN_UNEXPECTED_IF_NULL(out);
std::unique_lock<std::mutex> lck(mux_);
auto it = cache_miss_map_.find(row_id);
if (it != cache_miss_map_.end()) {
*out = it->second.GetMutablePointer();
} else {
// We will create a new one.
auto alloc = Services::GetAllocator<TensorRowRequest>();
auto r = cache_miss_map_.emplace(row_id, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>(alloc));
if (r.second) {
auto &mem = r.first->second;
RETURN_IF_NOT_OK(mem.allocate(1, row_id));
*out = mem.GetMutablePointer();
} else {
RETURN_STATUS_UNEXPECTED("Map insert fail.");
}
}
return Status::OK();
}
Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from super class first before adding our own
// specific logic
CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children");
@@ -199,6 +198,7 @@ Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from supe
RETURN_IF_NOT_OK(rc);
return Status::OK();
}

Status CacheMergeOp::ComputeColMap() {
CHECK_FAIL_RETURN_UNEXPECTED(child_[kCacheMissChildIdx] != nullptr, "Cache miss stream empty");
if (column_name_id_map().empty()) {
@@ -207,53 +207,13 @@ Status CacheMergeOp::ComputeColMap() {
CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map().empty(), "No column map detected");
return Status::OK();
}
Status CacheMergeOp::TensorRowRequest::Wait(TensorRow *out) {
RETURN_UNEXPECTED_IF_NULL(out);
// Block until the missing row is in the pool.
RETURN_IF_NOT_OK(use_count_.P());
std::unique_lock<std::mutex> lck(dq_mux_);
CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error");
*out = std::move(row_.front());
row_.pop_front();
return Status::OK();
}
void CacheMergeOp::TensorRowRequest::WakeUpAny(TensorRow &&row) {
std::unique_lock<std::mutex> lck(dq_mux_);
// Technically number of this row shows up in the cache miss stream is equal to the number
// of P() call. However the cleaner wants it too. So we need an extra copy.
if (GetState() == State::kEmpty) {
// We will do a deep copy
for (auto &ts : row) {
std::shared_ptr<Tensor> out_ts;
Tensor::CreateFromTensor(ts, &out_ts);
cleaner_copy_.push_back(out_ts);
}
cleaner_copy_.setId(row.getId());
// Change the state to dirty
SetState(State::kDirty);
}
row_.push_back(std::move(row));
// Bump up the use count by 1. This wake up any parallel worker which is waiting
// for this row.
use_count_.V();
}
Status CacheMergeOp::TensorRowRequest::Release(TensorRow *out) {
RETURN_UNEXPECTED_IF_NULL(out);
// We are not holding any mutex here because the cleaner isn't really touching the deque row_.
// In case we have multiple cleaners and they all see the copy, only one of them will
// get it.
auto expected = State::kDirty;
if (st_.compare_exchange_strong(expected, State::kClean)) {
*out = std::move(cleaner_copy_);
}
return Status::OK();
}

// Builder constructor. Creates the builder object.
CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
build_num_workers_ = cfg->num_parallel_workers();
build_op_connector_size_ = cfg->op_connector_size();
build_num_cleaners_ = 1;
build_num_cleaners_ = cfg->num_parallel_workers();
}

// Check if the required parameters are set by the builder.
@@ -311,5 +271,60 @@ Status CacheMergeOp::EofReceived(int32_t worker_id) {
MS_LOG(DEBUG) << "Cache merge sending eof";
return DatasetOp::EofReceived(worker_id);
}

Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowCacheRequest **out) {
RETURN_UNEXPECTED_IF_NULL(out);
std::unique_lock<std::mutex> lock(mux_);
auto it = io_request_.find(row_id);
if (it != io_request_.end()) {
*out = it->second.GetMutablePointer();
} else {
// We will create a new one.
auto alloc = Services::GetAllocator<TensorRowCacheRequest>();
auto r = io_request_.emplace(row_id, MemGuard<TensorRowCacheRequest, Allocator<TensorRowCacheRequest>>(alloc));
if (r.second) {
auto &mem = r.first->second;
RETURN_IF_NOT_OK(mem.allocate(1));
*out = mem.GetMutablePointer();
} else {
RETURN_STATUS_UNEXPECTED("Map insert fail.");
}
}
return Status::OK();
}

Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::shared_ptr<CacheClient> &cc,
const TensorRow &row) {
auto expected = State::kEmpty;
if (st_.compare_exchange_strong(expected, State::kDirty)) {
// We will do a deep copy but write directly into CacheRequest protobuf or shared memory
Status rc;
cleaner_copy_ =
std::make_shared<CacheRowRequest>(cc->server_connection_id_, cc->cookie(), cc->SupportLocalClient());
rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row);
if (rc.IsOk()) {
// Send the request async. The cleaner will check the return code.
rc = cc->PushRequest(cleaner_copy_);
}
if (rc.IsError()) {
// Clean up the shared pointer and reset the state back to empty
cleaner_copy_.reset();
st_ = State::kEmpty;
}
}
return Status::OK();
}

Status CacheMergeOp::TensorRowCacheRequest::CheckCacheResult() {
auto expected = State::kDirty;
if (st_.compare_exchange_strong(expected, State::kClean)) {
// Success or not, we will release the memory.
// We simply move it out of the structure and let it go out of scope.
auto cache_request = std::move(cleaner_copy_);
RETURN_IF_NOT_OK(cache_request->Wait());
return Status::OK();
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 30
- 14
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h View File

@@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_

#include <algorithm>
#include <atomic>
#include <deque>
#include <map>
@@ -28,6 +29,7 @@
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/dataset_iterator.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/queue_map.h"
#include "minddata/dataset/util/semaphore.h"

namespace mindspore {
@@ -36,28 +38,34 @@ namespace dataset {
/// stream
class CacheMergeOp : public ParallelOp {
public:
// Some handshake structures among the main thread, cleaner threads and parallel op threads.
class TensorRowRequest {
// Some handshake structures between CacheMissWorkerEntry and Cleaner
class TensorRowCacheRequest {
public:
enum class State : uint8_t {
kEmpty = 0, // No row in the deque
kEmpty = 0, // Initial state. Row hasn't arrived from cache miss stream yet.
kDirty = 1, // Cleaner hasn't flushed it to the cache server yet.
kClean = 2 // The row has been flushed already.
};
explicit TensorRowRequest(row_id_type id) : st_(State::kEmpty), use_count_(0) {}
~TensorRowRequest() = default;
TensorRowCacheRequest() : st_(State::kEmpty) {}
~TensorRowCacheRequest() = default;
/// Getter and Setter of the state
State GetState() const { return st_; }
void SetState(State newState) { st_ = newState; }
Status Wait(TensorRow *out);
void WakeUpAny(TensorRow &&row);
Status Release(TensorRow *out);
/// Take a tensor row and send rpc call to the server async
/// \param cc Cache client of the CacheMergeOp
/// \param row TensorRow to be sent to the server
/// \return Status object
/// \note Thread safe
Status AsyncSendCacheRequest(const std::shared_ptr<CacheClient> &cc, const TensorRow &row);

/// \brief We send the row to the server async so the CacheMissWorkerEntry can continue.
/// It is the cleaner that will check the result.
/// \return Status object
Status CheckCacheResult();

private:
std::mutex dq_mux_;
std::atomic<State> st_;
Semaphore use_count_;
std::deque<TensorRow> row_;
TensorRow cleaner_copy_;
std::shared_ptr<CacheRowRequest> cleaner_copy_;
};

constexpr static int kCacheHitChildIdx = 0; // Cache hit stream
@@ -80,6 +88,8 @@ class CacheMergeOp : public ParallelOp {
/// \return Builder setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
build_num_workers_ = num_workers;
// Adjust the number of cleaners to match the number of workers
build_num_cleaners_ = std::max(build_num_cleaners_, build_num_workers_);
return *this;
}

@@ -159,7 +169,6 @@ class CacheMergeOp : public ParallelOp {
/// \param workerId
/// \return Status object
Status CacheMissWorkerEntry(int32_t workerId);
Status GetRq(row_id_type row_id, TensorRowRequest **);

/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
@@ -188,11 +197,18 @@ class CacheMergeOp : public ParallelOp {

private:
std::mutex mux_;
std::map<row_id_type, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>> cache_miss_map_;
QueueMap<row_id_type, TensorRow> cache_miss_;
std::map<row_id_type, MemGuard<TensorRowCacheRequest, Allocator<TensorRowCacheRequest>>> io_request_;
std::unique_ptr<Queue<row_id_type>> io_que_;
std::shared_ptr<CacheClient> cache_client_;
int32_t num_cleaners_;

/// \brief Locate the cache request from the io_request_ map
/// \param row_id
/// \param out pointer to the cache request
/// \return Status object
Status GetRq(row_id_type row_id, TensorRowCacheRequest **out);

/// \brief These are the entry functions for the cleaner threads. Each cleaner is responsible for
/// moving cache miss TensorRow into the CacheServer.
/// \return Status object


+ 2
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc View File

@@ -142,7 +142,7 @@ Status CacheOp::WaitForCachingAllRows() {
}
// Get statistics from the server, and if we are not the one to create the cache,
// wait until the state changed from build phase to fetch base.
CacheClient::ServiceStat stat{};
CacheServiceStat stat{};
bool BuildPhaseDone = true;
do {
RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
@@ -157,6 +157,7 @@ Status CacheOp::WaitForCachingAllRows() {
MS_LOG(INFO) << "Number of rows cached: " << num_rows_;
MS_LOG(INFO) << "Number of rows cached in memory : " << stat.num_mem_cached;
MS_LOG(INFO) << "Number of rows spilled to disk : " << stat.num_disk_cached;
MS_LOG(INFO) << "Average cache size : " << stat.avg_cache_sz;
// Now all rows are cached and we have done a sync point check up. Next phase is
// is pick up fetch input from sampler and pass up to the caller.
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc View File

@@ -392,6 +392,13 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
ss_str = std::regex_replace(ss_str, std::regex("Num workers.*\n"), "");
ss_str = std::regex_replace(ss_str, std::regex("\\[workers.*\\]"), "");

// Filter out tcp/ip information
ss_str = std::regex_replace(ss_str, std::regex("Hostname.*\n"), "");
ss_str = std::regex_replace(ss_str, std::regex("Port.*\n"), "");
ss_str = std::regex_replace(ss_str, std::regex("Number of rpc workers.*\n"), "");
ss_str = std::regex_replace(ss_str, std::regex("Prefetch size.*\n"), "");
ss_str = std::regex_replace(ss_str, std::regex("Local client support.*\n"), "");

// Filter out Number of rows when generating the check sum
ss_str = std::regex_replace(ss_str, std::regex("Number of rows.*\n"), "");



+ 1
- 0
mindspore/ccsrc/minddata/dataset/include/status.h View File

@@ -73,6 +73,7 @@ enum class StatusCode : char {
kProfilingError = 10,
kBoundingBoxOutOfBounds = 11,
kBoundingBoxInvalidShape = 12,
kSyntaxError = 13,
// Make this error code the last one. Add new error code above it.
kUnexpectedError = 127
};


+ 1
- 1
mindspore/ccsrc/minddata/dataset/util/allocator.h View File

@@ -168,9 +168,9 @@ class MemGuard {
size_t GetSizeInBytes() const { return n_ * sizeof(T); }

private:
size_t n_;
allocator alloc_;
std::unique_ptr<T[]> ptr_;
size_t n_;
};
} // namespace dataset
} // namespace mindspore


+ 15
- 15
mindspore/ccsrc/minddata/dataset/util/arena.h View File

@@ -27,20 +27,20 @@
#define ARENA_WALL_OVERHEAD_SZ 32
namespace mindspore {
namespace dataset {
// This is a memory arena based on a treap data structure.
// The constructor of the Arena takes the size of the initial memory size (in MB).
// Internally we divide the memory into multiple blocks. Each block is 64 bytes.
// The treap contains all the free blocks with the relative memory address as key
// and the size of the block as priority.
//
// Initially the treap has only one root which is the whole memory piece.
//
// For memory suballocation, we pop the root node of the treap which contains the largest free block.
// We allocate what we need and return the rest back to the treap. We search for the first fit instead
// of the best fit so to give us a constant time in memory allocation.
//
// When a block of memory is freed. It is joined with the blocks before and after (if they are available) to
// form a bigger block.
/// This is a memory arena based on a treap data structure.
/// The constructor of the Arena takes the size of the initial memory size (in MB).
/// Internally we divide the memory into multiple blocks. Each block is 64 bytes.
/// The treap contains all the free blocks with the relative memory address as key
/// and the size of the block as priority.
///
/// Initially the treap has only one root which is the whole memory piece.
///
/// For memory suballocation, we pop the root node of the treap which contains the largest free block.
/// We allocate what we need and return the rest back to the treap. We search for the first fit instead
/// of the best fit so to give us a constant time in memory allocation.
///
/// When a block of memory is freed. It is joined with the blocks before and after (if they are available) to
/// form a bigger block.
class Arena : public MemoryPool {
public:
Arena(const Arena &) = delete;
@@ -78,7 +78,7 @@ class Arena : public MemoryPool {

static Status CreateArena(std::shared_ptr<Arena> *p_ba, size_t val_in_MB = 4096);

private:
protected:
std::mutex mux_;
Treap<uint64_t, uint64_t> tr_;
void *ptr_;


+ 9
- 0
mindspore/ccsrc/minddata/dataset/util/cache_pool.cc View File

@@ -140,13 +140,22 @@ Path CachePool::GetSpillPath() const {
}
CachePool::CacheStat CachePool::GetStat() const {
CacheStat cs{0};
int64_t total_sz = 0;
for (auto &it : *tree_) {
total_sz += it.sz;
if (it.ptr != nullptr) {
++cs.num_mem_cached;
} else {
++cs.num_disk_cached;
}
}
if (total_sz > 0) {
// integer arithmetic. NO need to cast to float or double.
cs.average_cache_sz = total_sz / (cs.num_disk_cached + cs.num_mem_cached);
if (cs.average_cache_sz == 0) {
cs.average_cache_sz = 1;
}
}
return cs;
}
Status CachePool::Spill(CachePool::DataLocator *dl) {


+ 1
- 0
mindspore/ccsrc/minddata/dataset/util/cache_pool.h View File

@@ -82,6 +82,7 @@ class CachePool : public Service {
struct CacheStat {
int64_t num_mem_cached;
int64_t num_disk_cached;
int64_t average_cache_sz;
};

/// \brief Constructor


+ 127
- 0
mindspore/ccsrc/minddata/dataset/util/queue_map.h View File

@@ -0,0 +1,127 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_

#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/semaphore.h"
#include "minddata/dataset/util/services.h"
namespace mindspore {
namespace dataset {
template <typename K, typename T>
/// \brief QueueMap is like a Queue but instead of there is a map of deque<T>.
/// Consumer will block if the corresponding deque is empty.
/// Producer can add an element of type T with key of type K to the map and
/// wake up any waiting consumer.
/// \tparam K key type
/// \tparam T payload of the map
class QueueMap {
public:
using key_type = K;
using value_type = T;

QueueMap() = default;
virtual ~QueueMap() = default;

/// Add an element <key, T> to the map and wake up any consumer that is waiting
/// \param key
/// \param payload
/// \return Status object
virtual Status Add(key_type key, T &&payload) {
RequestQueue *rq = nullptr;
RETURN_IF_NOT_OK(GetRq(key, &rq));
RETURN_IF_NOT_OK(rq->WakeUpAny(std::move(payload)));
return Status::OK();
}

/// Pop the front of the deque with key. Block if the deque is empty.
virtual Status PopFront(key_type key, T *out) {
RequestQueue *rq = nullptr;
RETURN_IF_NOT_OK(GetRq(key, &rq));
RETURN_IF_NOT_OK(rq->Wait(out));
return Status::OK();
}

protected:
/// This is a handshake structure between producer and consumer
class RequestQueue {
public:
RequestQueue() : use_count_(0) {}
~RequestQueue() = default;

Status Wait(T *out) {
RETURN_UNEXPECTED_IF_NULL(out);
// Block until the missing row is in the pool.
RETURN_IF_NOT_OK(use_count_.P());
std::unique_lock<std::mutex> lck(dq_mux_);
CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error");
*out = std::move(row_.front());
row_.pop_front();
return Status::OK();
}

Status WakeUpAny(T &&row) {
std::unique_lock<std::mutex> lck(dq_mux_);
row_.push_back(std::move(row));
// Bump up the use count by 1. This wake up any parallel worker which is waiting
// for this row.
use_count_.V();
return Status::OK();
}

private:
std::mutex dq_mux_;
Semaphore use_count_;
std::deque<T> row_;
};

/// Create or locate an element with matching key
/// \param key
/// \param out
/// \return Status object
Status GetRq(key_type key, RequestQueue **out) {
RETURN_UNEXPECTED_IF_NULL(out);
std::unique_lock<std::mutex> lck(mux_);
auto it = all_.find(key);
if (it != all_.end()) {
*out = it->second.GetMutablePointer();
} else {
// We will create a new one.
auto alloc = Services::GetAllocator<RequestQueue>();
auto r = all_.emplace(key, MemGuard<RequestQueue, Allocator<RequestQueue>>(alloc));
if (r.second) {
auto &mem = r.first->second;
RETURN_IF_NOT_OK(mem.allocate(1));
*out = mem.GetMutablePointer();
} else {
RETURN_STATUS_UNEXPECTED("Map insert fail.");
}
}
return Status::OK();
}

private:
std::mutex mux_;
std::map<K, MemGuard<RequestQueue, Allocator<RequestQueue>>> all_;
};
} // namespace dataset
} // namespace mindspore

#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_

+ 11
- 43
mindspore/ccsrc/minddata/dataset/util/services.cc View File

@@ -22,7 +22,6 @@
#include <stdlib.h>
#endif
#include <unistd.h>
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/util/circular_pool.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/task_manager.h"
@@ -59,35 +58,15 @@ std::string Services::GetUniqueID() {
return std::string(buffer, UNIQUEID_LEN);
}

TaskManager &Services::getTaskMgrInstance() {
Services &sm = GetInstance();
return *(static_cast<TaskManager *>(sm.sa_[kSlotTaskMgr_]));
}

CacheServer &Services::getCacheServer() {
Services &sm = GetInstance();
return *(static_cast<CacheServer *>(sm.sa_[kSlotCacheMgr_]));
}

Status Services::CreateAllInstances() {
// In order, TaskMgr, BufferMgr
Status rc;
sa_[kSlotTaskMgr_] = new (&rc, pool_) TaskManager();
RETURN_IF_NOT_OK(rc);
rc = sa_[kSlotTaskMgr_]->ServiceStart();
RETURN_IF_NOT_OK(rc);
// TODO(jesse) : Get the parameters from config file. Right now spill to /tmp and spawn 3 workers
#if !defined(_WIN32) && !defined(_WIN64)
sa_[kSlotCacheMgr_] = new (&rc, pool_) CacheServer("/tmp", 3);
RETURN_IF_NOT_OK(rc);
rc = sa_[kSlotCacheMgr_]->ServiceStart();
#else
sa_[kSlotCacheMgr_] = nullptr;
#endif
return rc;
// First one is always the TaskManager
RETURN_IF_NOT_OK(TaskManager::CreateInstance());
TaskManager &tm = TaskManager::GetInstance();
RETURN_IF_NOT_OK(tm.ServiceStart());
return Status::OK();
}

Services::Services() : pool_(nullptr), sa_{nullptr} {
Services::Services() : pool_(nullptr) {
Status rc = CircularPool::CreateCircularPool(&pool_, -1, 16, true); // each arena 16M
if (rc.IsError()) {
std::terminate();
@@ -95,22 +74,11 @@ Services::Services() : pool_(nullptr), sa_{nullptr} {
}

Services::~Services() noexcept {
try {
// In reverse order
CacheServer *cs = static_cast<CacheServer *>(sa_[kSlotCacheMgr_]);
if (cs != nullptr) {
(void)cs->ServiceStop();
cs->~CacheServer();
pool_->Deallocate(cs);
}
TaskManager *tm = static_cast<TaskManager *>(sa_[kSlotTaskMgr_]);
if (tm != nullptr) {
(void)tm->ServiceStop();
tm->~TaskManager();
pool_->Deallocate(tm);
}
} catch (const std::exception &e) {
// Do nothing.
// Shutdown in reverse order.
auto n = hook_.size();
while (n > 0) {
hook_.pop_back();
n = hook_.size();
}
}
} // namespace dataset


+ 20
- 12
mindspore/ccsrc/minddata/dataset/util/services.h View File

@@ -16,9 +16,11 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SERVICES_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_SERVICES_H_

#include <algorithm>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "minddata/dataset/util/memory_pool.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/service.h"
@@ -27,7 +29,7 @@
namespace mindspore {
namespace dataset {
class TaskManager;
class CacheServer;
class Services {
public:
static Status CreateInstance() {
@@ -59,10 +61,6 @@ class Services {

~Services() noexcept;

static TaskManager &getTaskMgrInstance();

static CacheServer &getCacheServer();

std::shared_ptr<MemoryPool> GetServiceMemPool() { return pool_; }

#if !defined(_WIN32) && !defined(_WIN64)
@@ -80,19 +78,29 @@ class Services {
return Allocator<T>(Services::GetInstance().GetServiceMemPool());
}

/// \brief Add a new service to the start up list.
/// \tparam T Class that implements Service
/// \return Status object and where the service is located in the hook_ list
template <typename T, typename... Args>
Status AddHook(T **out, Args &&... args) {
RETURN_UNEXPECTED_IF_NULL(out);
try {
(*out) = new T(std::forward<Args>(args)...);
std::unique_ptr<T> svc(*out);
hook_.push_back(std::move(svc));
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
}
return Status::OK();
}

private:
static std::once_flag init_instance_flag_;
static std::unique_ptr<Services> instance_;
// A small pool used for small objects that last until the
// Services Manager shuts down. Used by all sub-services.
std::shared_ptr<MemoryPool> pool_;
// We use pointers here instead of unique_ptr because we
// want to have ultimate control on the order of
// construction and destruction.
static constexpr int kSlotTaskMgr_ = 0;
static constexpr int kSlotCacheMgr_ = 1;
static constexpr int kNumServices_ = 2;
Service *sa_[kNumServices_];
std::vector<std::unique_ptr<Service>> hook_;

Services();



+ 1
- 0
mindspore/ccsrc/minddata/dataset/util/slice.h View File

@@ -86,6 +86,7 @@ class ReadableSlice {
class WritableSlice : public ReadableSlice {
public:
friend class StorageContainer;
friend class CacheService;
/// \brief Default constructor
WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {}
/// \brief This form of a constructor takes a pointer and its size.


+ 3
- 0
mindspore/ccsrc/minddata/dataset/util/status.cc View File

@@ -48,6 +48,9 @@ std::string CodeAsString(const StatusCode c) {
case StatusCode::kProfilingError:
s = "Error encountered while profiling";
break;
case StatusCode::kSyntaxError:
s = "Syntax error";
break;
case StatusCode::kUnexpectedError:
default:
s = "Unexpected error";


+ 1
- 0
mindspore/ccsrc/minddata/dataset/util/status.h View File

@@ -80,6 +80,7 @@ enum class StatusCode : char {
kProfilingError = 10,
kBoundingBoxOutOfBounds = 11,
kBoundingBoxInvalidShape = 12,
kSyntaxError = 13,
// Make this error code the last one. Add new error code above it.
kUnexpectedError = 127
};


+ 2
- 0
mindspore/ccsrc/minddata/dataset/util/task_manager.cc View File

@@ -21,6 +21,8 @@

namespace mindspore {
namespace dataset {
TaskManager *TaskManager::instance_ = nullptr;
std::once_flag TaskManager::init_instance_flag_;
// This takes the same parameter as Task constructor.
Status TaskManager::CreateAsyncTask(const std::string &my_name, const std::function<Status()> &f, TaskGroup *vg,
Task **task) {


+ 12
- 1
mindspore/ccsrc/minddata/dataset/util/task_manager.h View File

@@ -54,7 +54,16 @@ class TaskManager : public Service {

TaskManager &operator=(const TaskManager &) = delete;

static TaskManager &GetInstance() noexcept { return Services::getTaskMgrInstance(); }
static Status CreateInstance() {
std::call_once(init_instance_flag_, [&]() -> Status {
auto &svcManager = Services::GetInstance();
RETURN_IF_NOT_OK(svcManager.AddHook(&instance_));
return Status::OK();
});
return Status::OK();
}

static TaskManager &GetInstance() noexcept { return *instance_; }

Status DoServiceStart() override;

@@ -96,6 +105,8 @@ class TaskManager : public Service {
Status WatchDog();

private:
static std::once_flag init_instance_flag_;
static TaskManager *instance_;
RWLock lru_lock_;
SpinLock free_lock_;
SpinLock tg_lock_;


+ 11
- 2
mindspore/dataset/engine/cache_client.py View File

@@ -25,15 +25,22 @@ class DatasetCache:
A client to interface with tensor caching service
"""

def __init__(self, session_id=None, size=0, spilling=False):
def __init__(self, session_id=None, size=0, spilling=False, port=50052, prefetch_size=20):
check_uint32(session_id, "session_id")
check_uint64(size, "size")
type_check(spilling, (bool,), "spilling")
check_uint32(port, "port")
check_uint32(prefetch_size, "prefetch size")

self.session_id = session_id
self.size = size
self.spilling = spilling
self.cache_client = CacheClient(session_id, size, spilling)
self.port = port
self.prefetch_size = prefetch_size
self.cache_client = CacheClient(session_id, size, spilling, port, prefetch_size)

def GetStat(self):
return self.cache_client.GetStat()

def __deepcopy__(self, memodict):
if id(self) in memodict:
@@ -44,5 +51,7 @@ class DatasetCache:
new_cache.session_id = copy.deepcopy(self.session_id, memodict)
new_cache.spilling = copy.deepcopy(self.spilling, memodict)
new_cache.size = copy.deepcopy(self.size, memodict)
new_cache.port = copy.deepcopy(self.port, memodict)
new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict)
new_cache.cache_client = self.cache_client
return new_cache

+ 131
- 100
tests/ut/cpp/dataset/cache_op_test.cc View File

@@ -43,13 +43,18 @@ class MindDataTestCacheOp : public UT::DatasetOpTesting {
}
};

TEST_F(MindDataTestCacheOp, TestCacheServer) {
TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) {
Status rc;
CacheClient myClient(1, 0, true); // use arbitrary session of 1, size of 0, spilling is true
CacheClient::Builder builder;
// use arbitrary session of 1, size of 0, spilling// is true
builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true);
std::shared_ptr<CacheClient> myClient;
rc = builder.Build(&myClient);
ASSERT_TRUE(rc.IsOk());
// cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated.
rc = myClient.CreateCache(1, true);
EXPECT_TRUE(rc.IsOk());
std::cout << myClient << std::endl;
rc = myClient->CreateCache(1, true);
ASSERT_TRUE(rc.IsOk());
std::cout << *myClient << std::endl;

// Create a schema using the C api's
int32_t rank = 0; // not used
@@ -68,11 +73,11 @@ TEST_F(MindDataTestCacheOp, TestCacheServer) {

std::unordered_map<std::string, int32_t> map;
rc = testSchema->GetColumnNameMap(&map);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

// Test the CacheSchema api
rc = myClient.CacheSchema(map);
EXPECT_TRUE(rc.IsOk());
rc = myClient->CacheSchema(map);
ASSERT_TRUE(rc.IsOk());

// Create a tensor, take a snapshot and restore it back, and compare.
std::shared_ptr<Tensor> t;
@@ -88,48 +93,54 @@ TEST_F(MindDataTestCacheOp, TestCacheServer) {
TensorRow row;
row.push_back(t);
int64_t row_id;
rc = myClient.WriteRow(row, &row_id);
EXPECT_TRUE(rc.IsOk());
rc = myClient->WriteRow(row, &row_id);
ASSERT_TRUE(rc.IsOk());

// Switch off build phase.
rc = myClient.BuildPhaseDone();
EXPECT_TRUE(rc.IsOk());
rc = myClient->BuildPhaseDone();
ASSERT_TRUE(rc.IsOk());

// Now restore from cache.
row.clear();
rc = myClient.GetRows({row_id}, &tbl);
rc = myClient->GetRows({row_id}, &tbl);
row = tbl.front();
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
auto r = row.front();
std::cout << *r << std::endl;
// Compare
bool cmp = (*t == *r);
EXPECT_TRUE(cmp);
ASSERT_TRUE(cmp);

// Get back the schema and verify
std::unordered_map<std::string, int32_t> map_out;
rc = myClient.FetchSchema(&map_out);
EXPECT_TRUE(rc.IsOk());
rc = myClient->FetchSchema(&map_out);
ASSERT_TRUE(rc.IsOk());
cmp = (map_out == map);
EXPECT_TRUE(cmp);
ASSERT_TRUE(cmp);

// Test Purge and Destroy
rc = myClient.PurgeCache();
EXPECT_TRUE(rc.IsOk());
rc = myClient.DestroyCache();
EXPECT_TRUE(rc.IsOk());
rc = myClient->PurgeCache();
ASSERT_TRUE(rc.IsOk());
rc = myClient->DestroyCache();
ASSERT_TRUE(rc.IsOk());
}

TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) {
TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) {
// Clear the rc of the master thread if any
(void)TaskManager::GetMasterThreadRc();
TaskGroup vg;
Status rc;
CacheClient myClient(1, 1, true); // use arbitrary session of 1, size 1, spilling is true
// use arbitrary session of 1, size 1, spilling is true
CacheClient::Builder builder;
// use arbitrary session of 1, size of 0, spilling// is true
builder.SetSessionId(1).SetCacheMemSz(1).SetSpill(true);
std::shared_ptr<CacheClient> myClient;
rc = builder.Build(&myClient);
ASSERT_TRUE(rc.IsOk());
// cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated.
rc = myClient.CreateCache(1, true);
EXPECT_TRUE(rc.IsOk());
std::cout << myClient << std::endl;
rc = myClient->CreateCache(1, true);
ASSERT_TRUE(rc.IsOk());
std::cout << *myClient << std::endl;
std::shared_ptr<Tensor> t;
Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t);
t->SetItemAt<uint64_t>({0, 0}, 1);
@@ -146,19 +157,19 @@ TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) {
Status vg_rc = vg.CreateAsyncTask("Test agent", [&myClient, &row]() -> Status {
TaskManager::FindMe()->Post();
for (auto i = 0; i < 500; i++) {
RETURN_IF_NOT_OK(myClient.WriteRow(row));
RETURN_IF_NOT_OK(myClient->WriteRow(row));
}
return Status::OK();
});
EXPECT_TRUE(vg_rc.IsOk());
ASSERT_TRUE(vg_rc.IsOk());
}
ASSERT_TRUE(vg.join_all().IsOk());
ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOk());
rc = myClient.BuildPhaseDone();
rc = myClient->BuildPhaseDone();
ASSERT_TRUE(rc.IsOk());
// Get statistics from the server.
CacheClient::ServiceStat stat{};
rc = myClient.GetStat(&stat);
CacheServiceStat stat{};
rc = myClient->GetStat(&stat);
ASSERT_TRUE(rc.IsOk());
std::cout << stat.min_row_id << ":" << stat.max_row_id << ":" << stat.num_mem_cached << ":" << stat.num_disk_cached
<< "\n";
@@ -168,15 +179,15 @@ TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) {
for (auto i = stat.min_row_id; i <= stat.max_row_id; ++i) {
tbl.clear();
row.clear();
rc = myClient.GetRows({i}, &tbl);
EXPECT_TRUE(rc.IsOk());
rc = myClient->GetRows({i}, &tbl);
ASSERT_TRUE(rc.IsOk());
row = tbl.front();
auto r = row.front();
bool cmp = (*t == *r);
EXPECT_TRUE(cmp);
ASSERT_TRUE(cmp);
}
rc = myClient.DestroyCache();
EXPECT_TRUE(rc.IsOk());
rc = myClient->DestroyCache();
ASSERT_TRUE(rc.IsOk());
}

// Simple test with a repeated cache op over random data producer
@@ -187,7 +198,7 @@ TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) {
// |
// RandomDataOp
//
TEST_F(MindDataTestCacheOp, TestRandomDataCache1) {
TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) {
Status rc;
int32_t rank = 0; // not used
MS_LOG(INFO) << "UT test TestRandomDataCache1";
@@ -218,13 +229,18 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCache1) {
.SetDataSchema(std::move(testSchema))
.SetTotalRows(50) // 50 samples for now
.Build(&myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

// CacheOp
// size of 0, spilling is true
std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true);
CacheClient::Builder builder;
// use arbitrary session of 1, size of 0, spilling// is true
builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true);
std::shared_ptr<CacheClient> myClient;
rc = builder.Build(&myClient);
ASSERT_TRUE(rc.IsOk());
std::shared_ptr<CacheOp> myCacheOp;

int64_t num_samples = 0;
@@ -236,29 +252,29 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCache1) {
.SetRowsPerBuffer(4)
.SetSampler(std::move(seq_sampler))
.Build(&myCacheOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myCacheOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

// RepeatOp
uint32_t numRepeats = 4;
std::shared_ptr<RepeatOp> myRepeatOp;
rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

// Assign tree relations and root
rc = myRepeatOp->AddChild(myCacheOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myCacheOp->AddChild(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

MS_LOG(INFO) << "Launching tree and begin iteration";
rc = myTree->Prepare();
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

// quick check to see what tree looks like
std::ostringstream ss;
@@ -268,24 +284,24 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCache1) {
std::cout << *myClient << std::endl;

rc = myTree->Launch();
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

// Start the loop of reading tensors from our pipeline
DatasetIterator dI(myTree);
TensorRow tensorList;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
int rowCount = 0;
while (!tensorList.empty()) {
// Don't display these rows, just count them
MS_LOG(INFO) << "Row fetched #: " << rowCount;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rowCount++;
}
ASSERT_EQ(rowCount, 200);
rc = myClient->DestroyCache();
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
}

//// Simple test with a repeated cache op over random data producer.
@@ -297,7 +313,7 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCache1) {
//// |
//// RandomDataOp
////
TEST_F(MindDataTestCacheOp, TestRandomDataCacheSpill) {
TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) {
Status rc;
int32_t rank = 0; // not used
MS_LOG(INFO) << "UT test TestRandomDataCacheSpill";
@@ -328,15 +344,20 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCacheSpill) {
.SetDataSchema(std::move(testSchema))
.SetTotalRows(10)
.Build(&myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

// CacheOp
int64_t num_samples = 0;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 4, true);
CacheClient::Builder builder;
// use arbitrary session of 1, size of 0, spilling// is true
builder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true);
std::shared_ptr<CacheClient> myClient;
rc = builder.Build(&myClient);
ASSERT_TRUE(rc.IsOk());
std::shared_ptr<CacheOp> myCacheOp;
rc = CacheOp::Builder()
.SetNumWorkers(4)
@@ -344,60 +365,65 @@ TEST_F(MindDataTestCacheOp, TestRandomDataCacheSpill) {
.SetRowsPerBuffer(3)
.SetSampler(std::move(seq_sampler))
.Build(&myCacheOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myCacheOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

// RepeatOp
uint32_t numRepeats = 4;
std::shared_ptr<RepeatOp> myRepeatOp;
rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

// Assign tree relations and root
rc = myRepeatOp->AddChild(myCacheOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myCacheOp->AddChild(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

MS_LOG(INFO) << "Launching tree and begin iteration";
rc = myTree->Prepare();
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

std::cout << *myClient << std::endl;

rc = myTree->Launch();
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

// Start the loop of reading tensors from our pipeline
DatasetIterator dI(myTree);
TensorRow tensorList;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
int rowCount = 0;
while (!tensorList.empty()) {
// Don't display these rows, just count them
MS_LOG(INFO) << "Row fetched #: " << rowCount;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rowCount++;
}
ASSERT_EQ(rowCount, 40);
rc = myClient->DestroyCache();
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
}

TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) {
TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
Status rc;
int64_t num_samples = 0;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);

std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true);
CacheClient::Builder ccbuilder;
// use arbitrary session of 1, size of 0, spilling// is true
ccbuilder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true);
std::shared_ptr<CacheClient> myClient;
rc = ccbuilder.Build(&myClient);
ASSERT_TRUE(rc.IsOk());

// In a mappable dataset, it uses a complex interactions of cache lookup op and cache merge op.
// Rather than manually build this, the way to do it is to choose the position of the cache in the tree by
@@ -417,44 +443,44 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) {
.SetRecursive(true)
.SetImageFolderDir(datasets_root_path_ + "/testPK/data");
rc = builder.Build(&so);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

// RepeatOp
uint32_t numRepeats = 4;
std::shared_ptr<RepeatOp> myRepeatOp;
rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

auto myTree = std::make_shared<ExecutionTree>();
rc = myTree->AssociateNode(so);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

rc = myTree->AssociateNode(myCacheOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

rc = myTree->AssociateNode(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

rc = myRepeatOp->AddChild(myCacheOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myCacheOp->AddChild(so);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

rc = myTree->Prepare();
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myTree->Launch();
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator dI(myTree);
TensorRow tensorList;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
int rowCount = 0;
while (!tensorList.empty()) {
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
if (rc.IsError()) {
std::cout << rc << std::endl;
break;
@@ -464,7 +490,7 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) {
ASSERT_EQ(rowCount, 176);
std::cout << "Row count : " << rowCount << std::endl;
rc = myClient->DestroyCache();
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
}

//// Simple test with a repeated cache op over random data producer.
@@ -480,7 +506,7 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) {
//// |
//// RandomDataOp
////
TEST_F(MindDataTestCacheOp, TestCacheInheritSampler) {
TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) {
Status rc;
int32_t rank = 0; // not used
MS_LOG(INFO) << "UT test TestCacheInheritSampler";
@@ -517,57 +543,62 @@ TEST_F(MindDataTestCacheOp, TestCacheInheritSampler) {
.SetTotalRows(10)
.SetSampler(std::move(seq_sampler))
.Build(&myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

// CacheOp
std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 4, true);
CacheClient::Builder ccbuilder;
// use arbitrary session of 1, size of 0, spilling// is true
ccbuilder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true);
std::shared_ptr<CacheClient> myClient;
rc = ccbuilder.Build(&myClient);
ASSERT_TRUE(rc.IsOk());
std::shared_ptr<CacheOp> myCacheOp;
rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myCacheOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

// RepeatOp
uint32_t numRepeats = 4;
std::shared_ptr<RepeatOp> myRepeatOp;
rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

// Assign tree relations and root
rc = myRepeatOp->AddChild(myCacheOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myCacheOp->AddChild(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

MS_LOG(INFO) << "Launching tree and begin iteration";
rc = myTree->Prepare();
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

std::cout << *myClient << std::endl;

rc = myTree->Launch();
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());

// Start the loop of reading tensors from our pipeline
DatasetIterator dI(myTree);
TensorRow tensorList;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
int rowCount = 0;
while (!tensorList.empty()) {
// Don't display these rows, just count them
MS_LOG(INFO) << "Row fetched #: " << rowCount;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
rowCount++;
}
ASSERT_EQ(rowCount, 40);
rc = myClient->DestroyCache();
EXPECT_TRUE(rc.IsOk());
ASSERT_TRUE(rc.IsOk());
}

+ 7
- 4
tests/ut/python/dataset/test_cache_map.py View File

@@ -15,6 +15,8 @@
"""
Testing cache operator with mappable datasets
"""
import os
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as c_vision
from mindspore import log as logger
@@ -25,6 +27,7 @@ DATA_DIR = "../data/dataset/testImageNetData/train/"
GENERATE_GOLDEN = False


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic1():
"""
Test mappable leaf with cache op right over the leaf
@@ -53,7 +56,7 @@ def test_cache_map_basic1():

logger.info("test_cache_map_basic1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic2():
"""
Test mappable leaf with the cache op later in the tree above the map(decode)
@@ -82,7 +85,7 @@ def test_cache_map_basic2():

logger.info("test_cache_map_basic2 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic3():
"""
Test a repeat under mappable cache
@@ -116,7 +119,7 @@ def test_cache_map_basic3():
assert num_iter == 8
logger.info('test_cache_basic3 Ended.\n')

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic4():
"""
Test different rows result in core dump
@@ -141,7 +144,7 @@ def test_cache_map_basic4():
assert num_iter == 8
logger.info('test_cache_basic3 Ended.\n')

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_failure1():
"""
Test nested cache (failure)


+ 26
- 1
tests/ut/python/dataset/test_cache_nomap.py View File

@@ -15,6 +15,8 @@
"""
Testing cache operator with non-mappable datasets
"""
import os
import pytest
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as c_vision
@@ -25,6 +27,7 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

GENERATE_GOLDEN = False

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic1():
"""
A random dataset (a non mappable dataset) with a cache over it just after the leaf
@@ -54,6 +57,7 @@ def test_cache_nomap_basic1():
logger.info("test_cache_nomap_basic1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic2():
"""
A random dataset (a non mappable dataset) with a cache over it just after the leaf
@@ -85,6 +89,7 @@ def test_cache_nomap_basic2():
logger.info("test_cache_nomap_basic2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic3():
"""
A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf
@@ -112,9 +117,21 @@ def test_cache_nomap_basic3():

logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 12

# Contact the server to get the statistics
stat = some_cache.GetStat()
cache_sz = stat.avg_cache_sz
num_mem_cached = stat.num_mem_cached
num_disk_cached = stat.num_disk_cached

logger.info("Number of rows cached in memory: {}".format(num_mem_cached))
logger.info("Number of rows spilled to disk: {}".format(num_disk_cached))
logger.info("Average row cache size: {}".format(cache_sz))

logger.info("test_cache_nomap_basic3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic4():
"""
A TF reader dataset (a non mappable dataset) with a map decode and cache after it
@@ -155,6 +172,7 @@ def test_cache_nomap_basic4():
logger.info("test_cache_nomap_basic4 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic5():
"""
A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf
@@ -191,6 +209,7 @@ def test_cache_nomap_basic5():
logger.info("test_cache_nomap_basic5 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic6():
"""
A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf
@@ -230,6 +249,7 @@ def test_cache_nomap_basic6():
logger.info("test_cache_nomap_basic6 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic7():
"""
A TF reader dataset (a non mappable dataset) that uses global shuffle, and is cached followed by
@@ -265,6 +285,7 @@ def test_cache_nomap_basic7():
logger.info("test_cache_nomap_basic7 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_allowed_share1():
"""
It is allowed to share the cache between the following two trees:
@@ -280,7 +301,7 @@ def test_cache_nomap_allowed_share1():

ds.config.set_seed(1)
# This dataset has 3 records in it only
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True, prefetch_size=32)
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
ds1 = ds1.repeat(4)

@@ -300,6 +321,7 @@ def test_cache_nomap_allowed_share1():
logger.info("test_cache_nomap_allowed_share1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_allowed_share2():
"""
It is allowed to share the cache between the following two trees (with map decode):
@@ -341,6 +363,7 @@ def test_cache_nomap_allowed_share2():
logger.info("test_cache_nomap_allowed_share2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_allowed_share3():
"""
It is allowed to share the cache between the following two trees (different shard ids):
@@ -376,6 +399,7 @@ def test_cache_nomap_allowed_share3():
logger.info("test_cache_nomap_allowed_share3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_allowed_share4():
"""
It is allowed to share the cache between the following two trees:
@@ -414,6 +438,7 @@ def test_cache_nomap_allowed_share4():
logger.info("test_cache_nomap_allowed_share4 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_disallowed_share1():
"""
It is not allowed to share the cache between the following two trees:


Loading…
Cancel
Save