Browse Source

Rebase up to 542a52fbf8

tags/v1.1.0
Lixia Chen Jesse Lee 5 years ago
parent
commit
572c5e5f29
53 changed files with 3545 additions and 526 deletions
  1. +24
    -2
      mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt
  2. +15
    -6
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc
  3. +2
    -1
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h
  4. +29
    -15
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
  5. +7
    -2
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h
  6. +19
    -5
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h
  7. +3
    -2
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc.proto
  8. +27
    -3
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc
  9. +2
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h
  10. +220
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc
  11. +81
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.h
  12. +2
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc
  13. +224
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_numa.cc
  14. +195
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_numa.h
  15. +69
    -81
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_pool.cc
  16. +20
    -14
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_pool.h
  17. +62
    -8
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc
  18. +32
    -18
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h
  19. +266
    -124
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc
  20. +53
    -32
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h
  21. +126
    -103
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc
  22. +13
    -27
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h
  23. +23
    -1
      mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs
  24. +32
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/perf/CMakeLists.txt
  25. +48
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_msg.cc
  26. +78
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_msg.h
  27. +39
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf.cc
  28. +39
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf.proto
  29. +575
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.cc
  30. +100
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.h
  31. +44
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline.cc
  32. +471
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.cc
  33. +117
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.h
  34. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/cache/storage_container.cc
  35. +0
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/storage_container.h
  36. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.cc
  37. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.h
  38. +19
    -20
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc
  39. +0
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h
  40. +1
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc
  41. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc
  42. +77
    -1
      mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.cc
  43. +73
    -0
      mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h
  44. +0
    -3
      mindspore/ccsrc/minddata/dataset/util/CMakeLists.txt
  45. +5
    -0
      mindspore/ccsrc/minddata/dataset/util/allocator.h
  46. +12
    -0
      mindspore/ccsrc/minddata/dataset/util/path.h
  47. +12
    -1
      mindspore/ccsrc/minddata/dataset/util/task.cc
  48. +14
    -1
      mindspore/ccsrc/minddata/dataset/util/task.h
  49. +0
    -1
      tests/ut/cpp/dataset/cache_op_test.cc
  50. +10
    -1
      tests/ut/python/cachetests/cachetest_py.sh
  51. +166
    -47
      tests/ut/python/dataset/test_cache_map.py
  52. +85
    -0
      tests/ut/python/dataset/test_cache_nomap.py
  53. +10
    -0
      tests/ut/python/test_server_stop_testcase.sh

+ 24
- 2
mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt View File

@@ -1,3 +1,4 @@
add_subdirectory(perf EXCLUDE_FROM_ALL)
include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
@@ -5,6 +6,18 @@ ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engin
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)

# Try to find numa header file and its library
find_file(NUMA_HDR NAMES "numa.h")

if (EXISTS ${NUMA_HDR})
ADD_DEFINITIONS(-DNUMA_ENABLED)
MESSAGE("Numa package found")
endif ()

if (${CMAKE_SYSTEM_NAME} MATCHES "Linux")
ADD_DEFINITIONS(-DCACHE_LOCAL_CLIENT)
endif ()

add_library(engine-cache-client OBJECT
cache_client.cc
cache_fbb.cc
@@ -20,8 +33,13 @@ if (ENABLE_CACHE)
${CACHE_GRPC_SRCS}
cache_grpc_server.cc
cache_arena.cc
cache_hw.cc
cache_numa.cc
cache_pool.cc
cache_service.cc
cache_server.cc)
cache_server.cc
storage_manager.cc
storage_container.cc)

add_executable(cache_server cache_main.cc)
target_link_libraries(cache_server
@@ -39,6 +57,10 @@ if (ENABLE_CACHE)
target_link_libraries(cache_server mindspore::glog)
endif ()

if (EXISTS ${NUMA_HDR})
target_link_libraries(cache_server numa)
endif ()

add_executable(cache_admin cache_admin.cc cache_admin_arg.cc)
target_link_libraries(cache_admin _c_dataengine _c_mindrecord ${PYTHON_LIBRARIES})

@@ -49,7 +71,7 @@ if (ENABLE_CACHE)
add_dependencies(engine-cache-server generated_engine_files)

else ()
ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PRTO_HDRS cache_grpc.proto)
ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PROTO_HDRS cache_grpc.proto)
target_sources(engine-cache-client PUBLIC ${CACHE_PROTO_SRCS})
endif ()



+ 15
- 6
mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc View File

@@ -18,6 +18,7 @@
#include <sys/stat.h>
#include <sys/wait.h>
#include <unistd.h>
#include <algorithm>
#include <cerrno>
#include <iomanip>
#include <iostream>
@@ -31,7 +32,9 @@

namespace mindspore {
namespace dataset {

const int32_t CacheAdminArgHandler::kDefaultNumWorkers = std::thread::hardware_concurrency() > 2
? std::thread::hardware_concurrency() / 2
: 1;
const char CacheAdminArgHandler::kServerBinary[] = "cache_server";
const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp";

@@ -304,8 +307,10 @@ Status CacheAdminArgHandler::Validate() {
}

// Additional checks here
if (num_workers_ < 1 || num_workers_ > 100)
return Status(StatusCode::kSyntaxError, "Number of workers must be in range of 1 and 100.");
auto max_num_workers = std::max<int32_t>(std::thread::hardware_concurrency(), 100);
if (num_workers_ < 1 || num_workers_ > max_num_workers)
return Status(StatusCode::kSyntaxError,
"Number of workers must be in range of 1 and " + std::to_string(max_num_workers) + ".");
if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3).");
if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1)
return Status(StatusCode::kSyntaxError, "Memory cap ratio should be positive and no greater than 1");
@@ -354,13 +359,15 @@ Status CacheAdminArgHandler::RunCommand() {
std::vector<SessionCacheInfo> session_info = rq->GetSessionCacheInfo();
if (!session_info.empty()) {
std::cout << std::setw(12) << "Session" << std::setw(12) << "Cache Id" << std::setw(12) << "Mem cached"
<< std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" << std::endl;
<< std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" << std::setw(10) << "Numa hit"
<< std::endl;
for (auto curr_session : session_info) {
std::string cache_id;
std::string stat_mem_cached;
std::string stat_disk_cached;
std::string stat_avg_cached;
int32_t crc = (curr_session.connection_id & 0x00000000FFFFFFFF);
std::string stat_numa_hit;
uint32_t crc = (curr_session.connection_id & 0x00000000FFFFFFFF);
cache_id = (curr_session.connection_id == 0) ? "n/a" : std::to_string(crc);
stat_mem_cached =
(curr_session.stats.num_mem_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_mem_cached);
@@ -368,10 +375,12 @@ Status CacheAdminArgHandler::RunCommand() {
(curr_session.stats.num_disk_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_disk_cached);
stat_avg_cached =
(curr_session.stats.avg_cache_sz == 0) ? "n/a" : std::to_string(curr_session.stats.avg_cache_sz);
stat_numa_hit =
(curr_session.stats.num_numa_hit == 0) ? "n/a" : std::to_string(curr_session.stats.num_numa_hit);

std::cout << std::setw(12) << curr_session.session_id << std::setw(12) << cache_id << std::setw(12)
<< stat_mem_cached << std::setw(12) << stat_disk_cached << std::setw(16) << stat_avg_cached
<< std::endl;
<< std::setw(10) << stat_numa_hit << std::endl;
}
} else {
std::cout << "No active sessions." << std::endl;


+ 2
- 1
mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h View File

@@ -21,6 +21,7 @@
#include <memory>
#include <string>
#include <sstream>
#include <thread>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/cache/cache_client.h"

@@ -29,7 +30,7 @@ namespace dataset {

class CacheAdminArgHandler {
public:
static constexpr int32_t kDefaultNumWorkers = 32;
static const int32_t kDefaultNumWorkers;
static constexpr int32_t kDefaultSharedMemorySizeInGB = 4;
static constexpr int32_t kDefaultLogLevel = 1;
static constexpr float kMemoryCapRatio = 0.8;


+ 29
- 15
mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc View File

@@ -17,7 +17,6 @@
#include <iomanip>
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_fbb.h"
#include "minddata/dataset/util/bit.h"

@@ -59,6 +58,7 @@ CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool
: server_connection_id_(0),
cache_mem_sz_(cache_mem_sz),
spill_(spill),
client_id_(-1),
local_bypass_(false),
hostname_(std::move(hostname)),
port_(port),
@@ -71,6 +71,22 @@ CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool

CacheClient::~CacheClient() {
cache_miss_keys_wp_.Set();
if (client_id_ != -1) {
try {
// Send a message to the server, saying I am done.
auto rq = std::make_shared<ConnectResetRequest>(server_connection_id_, client_id_);
Status rc = PushRequest(rq);
if (rc.IsOk()) {
rc = rq->Wait();
if (rc.IsOk()) {
MS_LOG(INFO) << "Disconnect from server successful";
}
}
} catch (const std::exception &e) {
// Can't do anything in destructor. So just log the error.
MS_LOG(ERROR) << e.what();
}
}
(void)comm_->ServiceStop();
}

@@ -85,7 +101,7 @@ void CacheClient::Print(std::ostream &out) const {
}

Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
auto rq = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
auto rq = std::make_shared<CacheRowRequest>(this);
RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(this, row));
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
@@ -104,7 +120,7 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
for (auto i = 0; i < num_rows; ++i) {
TensorRow row;
RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
arr[i] = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
arr[i] = std::make_shared<CacheRowRequest>(this);
RETURN_IF_NOT_OK(arr[i]->SerializeCacheRowRequest(this, row));
RETURN_IF_NOT_OK(PushRequest(arr[i]));
}
@@ -118,7 +134,7 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {

Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
auto rq = std::make_shared<BatchFetchRequest>(server_connection_id_, row_id, SupportLocalClient());
auto rq = std::make_shared<BatchFetchRequest>(this, row_id);
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
int64_t mem_addr;
@@ -167,7 +183,7 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
lck.Unlock(); // GetStat will grab the mutex again. So unlock it to prevent deadlock.
CacheServiceStat stat{};
RETURN_IF_NOT_OK(GetStat(&stat));
if (stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase)) {
if (stat.cache_service_state == static_cast<uint8_t>(CacheServiceState::kFetchPhase)) {
return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase");
}
} else {
@@ -183,18 +199,16 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
// Start the comm layer to receive reply
RETURN_IF_NOT_OK(comm_->ServiceStart());
// Initiate connection
auto rq = std::make_shared<CreateCacheRequest>(cinfo_, cache_mem_sz_, createFlag);
auto rq = std::make_shared<CreateCacheRequest>(this, cinfo_, cache_mem_sz_, createFlag);
RETURN_IF_NOT_OK(PushRequest(rq));
Status rc = rq->Wait();
if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) {
std::string cookie;
rq->ParseResult(&server_connection_id_, &cookie);
if (rc.IsOk()) {
// The 1st guy creating the cache will get a cookie back.
// But this object may be shared among pipelines and we don't want
// overwrite it.
cookie_ = cookie;
}
bool success = (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey);
// If we get kDuplicateKey, it just means we aren't the first one to create the cache,
// and we will continue to parse the result.
if (rc.get_code() == StatusCode::kDuplicateKey) {
RETURN_IF_NOT_OK(rq->PostReply());
}
if (success) {
// Attach to shared memory for local client
RETURN_IF_NOT_OK(comm_->AttachToSharedMemory(port_, &local_bypass_));
}


+ 7
- 2
mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h View File

@@ -47,6 +47,9 @@ namespace dataset {
class CacheClient {
public:
friend class CacheMergeOp;
friend class CreateCacheRequest;
friend class CacheRowRequest;
friend class BatchFetchRequest;

/// \brief A builder to help creating a CacheClient object
class Builder {
@@ -115,7 +118,7 @@ class CacheClient {
session_id_type GetSessionId() const { return session_id_; }
uint64_t GetCacheMemSz() const { return cache_mem_sz_; }
bool isSpill() const { return spill_; }
const std::string &getHostname() const { return hostname_; }
const std::string &GetHostname() const { return hostname_; }
int32_t GetPort() const { return port_; }
int32_t GetNumConnections() const { return num_connections_; }
int32_t GetPrefetchSize() const { return prefetch_size_; }
@@ -256,8 +259,10 @@ class CacheClient {
CacheClientInfo cinfo_;
// The server_connection_id_ is the actual id we use for operations after the cache is built
connection_id_type server_connection_id_;
// Some magic cookie returned from the cache server.
// Some magic cookie/id returned from the cache server.
std::string cookie_;
int32_t client_id_;
std::vector<int32_t> cpu_list_;
// Comm layer
bool local_bypass_;
std::string hostname_;


+ 19
- 5
mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h View File

@@ -20,11 +20,6 @@
/// both client and server side codes. Do not put code that is not common here.
/// There are client and server specific header files.

// On platform like Windows, we may support only tcp/ip clients
#if !defined(_WIN32) && !defined(_WIN64)
#define CACHE_LOCAL_CLIENT 1
#endif

#ifdef ENABLE_CACHE
#include <grpcpp/grpcpp.h>
#endif
@@ -50,6 +45,9 @@ constexpr static uint32_t kDataIsInSharedMemory = 2;
/// \brief Size of each message used in message queue.
constexpr static int32_t kSharedMessageSize = 2048;

/// \brief State of CacheService at the server.
enum class CacheServiceState : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking };

/// \brief Convert a Status object into a protobuf
/// \param rc[in] Status object
/// \param reply[in/out] pointer to pre-allocated protobuf object
@@ -61,6 +59,22 @@ inline void Status2CacheReply(const Status &rc, CacheReply *reply) {
/// \param port
/// \return unix socket url
inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); }

/// \brief Round up to the next 4k
inline int64_t round_up_4K(int64_t sz) {
// Since 4096 is a power of 2, a simple way to round up is add 4095 and mask off all the
// bits of 4095
return static_cast<uint64_t>(sz + 4095) & ~static_cast<uint64_t>(4095);
}

/// Memory policy
enum CachePoolPolicy : int8_t { kOnNode, kPreferred, kLocal, kInterleave, kNone };

/// Misc typedef
using worker_id_t = int32_t;
using numa_id_t = int32_t;
using cpu_id_t = int32_t;

} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_

+ 3
- 2
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc.proto View File

@@ -32,12 +32,13 @@ message CacheRequest {
uint32 flag = 2;
oneof connect_info {
// The server_connection_id is the actual id we use for operations after the cache is built
int64 connection_id = 3;
uint64 connection_id = 3;
// But some request like CreateCache we have to use the session id and crc to connect to the server.
CacheClientInfo connection_info = 4;
}
int32 client_id = 5;
// Everything else is just vector of buffers
repeated bytes buf_data = 5;
repeated bytes buf_data = 6;
}

message CacheReply {


+ 27
- 3
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc View File

@@ -74,6 +74,9 @@ Status CacheServerGreeterImpl::Run() {
#if CACHE_LOCAL_CLIENT
RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_));
MS_LOG(INFO) << "Creation of local socket and shared memory successful";
auto cs = CacheServer::GetInstance().GetHWControl();
// This shared memory is a hot memory and we will interleave among all the numa nodes.
cs->InterleaveMemory(const_cast<void *>(shm_pool_->SharedMemoryBaseAddr()), shm_pool_sz_in_gb_ * 1073741824L);
#endif
} else {
std::string errMsg = "Fail to start server. ";
@@ -127,8 +130,13 @@ Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grp
st_ = STATE::PROCESS;
svc->RequestCacheServerRequest(&ctx_, &rq_, &responder_, cq, cq, this);
} else if (st_ == STATE::PROCESS) {
auto &cs = CacheServer::GetInstance();
// Get a new tag and handle the next request before we serve the current request.
// The tag will be recycled when its state is changed to FINISH
// The tag will be recycled when its state is changed to FINISH.
// The number of free list queues is the same as the number of grpc threads.
// Where we get the free list it doesn't matter (as long we return it back to the right queue).
// We can round robin, use the qid or even use the worker id. We will use the free list queue
// where the current request comes from.
CacheServerRequest *next_rq;
RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(myQID, &next_rq));
RETURN_IF_NOT_OK((*next_rq)(svc, cq));
@@ -138,8 +146,24 @@ Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grp
type_ = static_cast<RequestType>(rq_.type());
// Now we pass the address of this instance to CacheServer's main loop.
MS_LOG(DEBUG) << "Handle request " << *this;
auto &cs = CacheServer::GetInstance();
RETURN_IF_NOT_OK(cs.PushRequest(myQID, this));
// We will distribute the request evenly (or randomly) over all the numa nodes.
// The exception is BatchFetch which we need to pre-process here.
if (type_ == BaseRequest::RequestType::kBatchFetchRows) {
rc_ = cs.BatchFetchRows(&rq_, &reply_);
if (!rc_.IsInterrupted()) {
Status2CacheReply(rc_, &reply_);
st_ = CacheServerRequest::STATE::FINISH;
responder_.Finish(reply_, grpc::Status::OK, this);
} else {
return rc_;
}
} else {
// When the number of grpc workers is the same as the server workers, we will use this queue id
// and push to the corresponding queue.
bool random = cs.GetNumWorkers() != cs.GetNumGrpcWorkers();
worker_id_t worker_id = random ? cs.GetRandomWorker() : myQID;
RETURN_IF_NOT_OK(cs.PushRequest(worker_id, this));
}
} else if (st_ == STATE::FINISH) {
MS_LOG(DEBUG) << *this << " Finished.";
// Return back to the free list.


+ 2
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h View File

@@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_

#include <atomic>
#include <memory>
#include <string>
#include <utility>
@@ -34,6 +35,7 @@ namespace dataset {
class CacheServerRequest : public BaseRequest {
public:
friend class CacheServer;
friend class CacheService;
enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
explicit CacheServerRequest(int32_t queue_id)
: BaseRequest::BaseRequest(BaseRequest::RequestType::kRequestUnknown),


+ 220
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc View File

@@ -0,0 +1,220 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_hw.h"
#ifdef NUMA_ENABLED
#include <numa.h>
#endif
#include <sched.h>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <fstream>
#include <regex>
#include <thread>
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
CacheServerHW::CacheServerHW() {
num_cpus_ = std::thread::hardware_concurrency();
MS_LOG(DEBUG) << "Number of cpu(s) : " << num_cpus_;
#ifdef NUMA_ENABLED
if (numa_enabled()) {
MS_LOG(WARNING) << "Numa support enabled";
for (auto i = 0; i <= numa_max_node(); ++i) {
int64_t free_avail;
int64_t mem_avail = numa_node_size(i, &free_avail);
MS_LOG(INFO) << "Total physical/free RAM in bytes at node " << i << " : " << mem_avail << "/" << free_avail;
}
}
#endif
}

int64_t CacheServerHW::GetTotalSystemMemory() {
auto pages = sysconf(_SC_PHYS_PAGES);
auto page_size = sysconf(_SC_PAGE_SIZE);
auto total = static_cast<int64_t>(pages) * static_cast<int64_t>(page_size);
MS_LOG(INFO) << "Total physical RAM in bytes: " << total;
return total;
}

Status CacheServerHW::SetDefaultMemoryPolicy(CachePoolPolicy policy) {
#ifdef NUMA_ENABLED
if (numa_enabled()) {
// Set our default memory policy.
switch (policy) {
case kLocal:
numa_set_localalloc();
MS_LOG(DEBUG) << "Setting memory default policy to local node. Low level code may override the setting";
break;
case kInterleave:
numa_set_interleave_mask(numa_all_nodes_ptr);
MS_LOG(DEBUG) << "Numa affinity is turned off. Use interleave memory policy as default.";
break;
case kOnNode:
case kPreferred:
RETURN_STATUS_UNEXPECTED("Unsupported memory policy");
break;
case kNone:
default:
// No action taken.
break;
}
}
#endif
return Status::OK();
}

Status CacheServerHW::GetNumaNodeInfo() {
std::set<Path> numa_nodes_;
Path node(kSysNodePath);
auto it = Path::DirIterator::OpenDirectory(&node);
if (it == nullptr) {
MS_LOG(WARNING) << "Unable to open directory " << kSysNodePath << ". Skip scanning hardware info";
return Status::OK();
}
auto isdigit_string = [](const char *str) -> bool {
bool r = true;
for (auto i = 0; i < strlen(str); ++i) {
if (!std::isdigit(str[i])) {
r = false;
break;
}
}
return r;
};
// Look for name starts with 'node' and followed by digits.
const char kNodeName[] = "node";
while (it->hasNext()) {
auto p = it->next();
const std::string entry = p.Basename();
const char *name = entry.data();
if (strncmp(name, kNodeName, 4) == 0 && isdigit_string(name + strlen(kNodeName))) {
numa_nodes_.insert(p);
}
}
// There should be at least one. But if not found in any case, just move on the
// rest of the server start up.
if (numa_nodes_.empty()) {
MS_LOG(WARNING) << "No numa nodes ? Skip scanning hardware info";
return Status::OK();
}
// For each numa node, get a list of CPU that is associated with it.
const char kCpuList[] = "cpulist";
auto r = std::regex("[0-9]*-[0-9]*");
for (Path p : numa_nodes_) {
auto node_dir = p.Basename().data();
numa_id_t numa_node = strtol(node_dir + strlen(kNodeName), nullptr, 10);
Path f = p / kCpuList;
std::ifstream fs(f.toString());
CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + f.toString());
std::string cpu_string;
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
int32_t cpu_cnt = 0;
while (getline(fs, cpu_string)) {
// Now we parse the content of cpu_string.
std::sregex_iterator iter(cpu_string.begin(), cpu_string.end(), r);
std::sregex_iterator end;
while (iter != end) {
auto match = iter->str();
auto pos = match.find_first_of('-');
std::string min = match.substr(0, pos);
std::string max = match.substr(pos + 1);
cpu_id_t cpu_min = strtol(min.data(), nullptr, 10);
cpu_id_t cpu_max = strtol(max.data(), nullptr, 10);
MS_LOG(DEBUG) << "Numa node " << numa_node << " CPU(s) : " << cpu_min << "-" << cpu_max;
for (int i = cpu_min; i <= cpu_max; ++i) {
CPU_SET(i, &cpuset);
++cpu_cnt;
}
++iter;
}
}
CHECK_FAIL_RETURN_UNEXPECTED(!fs.bad(), "Fail to read file: " + f.toString());
fs.close();
// Remember which cpu is attached to this numa node.
numa_cpuset_.emplace(numa_node, cpuset);
numa_cpu_cnt_.emplace(numa_node, cpu_cnt);
}
MS_LOG(DEBUG) << "Number of numa nodes : " << numa_cpuset_.size();
return Status::OK();
}

Status CacheServerHW::SetAffinity(const Task &tk, numa_id_t numa_node) {
auto r = numa_cpuset_.find(numa_node);
if (r != numa_cpuset_.end()) {
auto err = pthread_setaffinity_np(tk.GetNativeHandle(), sizeof(r->second), &r->second);
if (err) {
std::string errMsg = "Unable to set affiity. Errno = " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg);
}
} else {
RETURN_STATUS_UNEXPECTED("Numa node " + std::to_string(numa_node) + " not found");
}
return Status::OK();
}

std::vector<cpu_id_t> CacheServerHW::GetCpuList(numa_id_t numa_id) {
std::vector<cpu_id_t> v;
auto it = numa_cpuset_.find(numa_id);
if (it != numa_cpuset_.end()) {
auto &cpu_set = it->second;
for (auto i = 0; i < num_cpus_; ++i) {
if (CPU_ISSET(i, &cpu_set)) {
v.push_back(i);
}
}
}
return v;
}

numa_id_t CacheServerHW::GetMyNode() const {
numa_id_t node_id = 0;
auto cpu = sched_getcpu();
#ifdef NUMA_ENABLED
node_id = numa_node_of_cpu(cpu);
#else
bool found = false;
for (auto it : numa_cpuset_) {
cpu_set_t &cpu_set = it.second;
if (CPU_ISSET(cpu, &cpu_set)) {
node_id = it.first;
found = true;
break;
}
}
MS_LOG(DEBUG) << "cpu id " << cpu << " found : " << std::boolalpha << found;
#endif
return node_id;
}

void CacheServerHW::InterleaveMemory(void *ptr, size_t sz) {
#ifdef NUMA_ENABLED
if (numa_enabled()) {
numa_interleave_memory(ptr, sz, numa_all_nodes_ptr);
}
#endif
}

bool CacheServerHW::numa_enabled() {
#ifdef NUMA_ENABLED
return (numa_available() != -1);
#else
return false;
#endif
}
} // namespace dataset
} // namespace mindspore

+ 81
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.h View File

@@ -0,0 +1,81 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_HW_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_HW_H_

#ifdef NUMA_ENABLED
#include <numa.h>
#endif
#include <sched.h>
#include <stdlib.h>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/util/memory_pool.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task.h"

namespace mindspore {
namespace dataset {
class CacheServerHW {
public:
CacheServerHW();
~CacheServerHW() = default;

/// \brief Get Numa node info without using numa library
/// \return Status object
Status GetNumaNodeInfo();

/// \brief Set thread affinity
Status SetAffinity(const Task &tk, numa_id_t numa_node);

/// \brief Get total number of cpu(s)
int32_t GetCpuCount() const { return num_cpus_; }

/// \brief Get total number of numa nodes
int32_t GetNumaNodeCount() const { return numa_cpuset_.empty() ? 1 : numa_cpuset_.size(); }

/// \brief Get a list of cpu for a given numa node.
std::vector<cpu_id_t> GetCpuList(numa_id_t numa_id);

static bool numa_enabled();

/// \brief Return the numa the current thread is running on.
numa_id_t GetMyNode() const;

/// \brief Interleave a given memory block. Used by shared memory only.
static void InterleaveMemory(void *ptr, size_t sz);

/// \brief Set default memory policy.
static Status SetDefaultMemoryPolicy(CachePoolPolicy);

/// \brief This returns the size (in bytes) of the physical RAM on the machine.
/// \return the size (in bytes) of the physical RAM on the machine.
static int64_t GetTotalSystemMemory();

private:
constexpr static char kSysNodePath[] = "/sys/devices/system/node";
int32_t num_cpus_;
std::map<numa_id_t, cpu_set_t> numa_cpuset_;
std::map<numa_id_t, int32_t> numa_cpu_cnt_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_HW_H_

+ 2
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc View File

@@ -54,6 +54,8 @@ ds::Status SendSyncCommand(int32_t port, ds::BaseRequest::RequestType type, ds::
#endif
try {
rq->set_type(static_cast<int16_t>(type));
rq->set_client_id(-1);
rq->set_flag(0);
grpc::ChannelArguments args;
grpc::ClientContext ctx;
grpc::CompletionQueue cq;


+ 224
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_numa.cc View File

@@ -0,0 +1,224 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <iterator>
#include <limits>
#include "minddata/dataset/engine/cache/cache_hw.h"
#include "minddata/dataset/engine/cache/cache_numa.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
NumaMemoryPool::NumaMemoryPool(std::shared_ptr<CacheServerHW> hw, float memory_cap_ratio)
: hw_(std::move(hw)), memory_cap_ratio_(memory_cap_ratio) {
int64_t total_avail = 0;
// We will create a number of small Arenas to spread out the server threads so it
// will be less contention. If we link with the numa library, i.e. if
// NUMA_ENABLED is defined, we will make use of the low level numa library such that
// each Arena solely comes from one particular socket.
// The total number of Arenas will be controlled under the number of cpus.
auto num_cpus = hw_->GetCpuCount();
memory_segments_.reserve(num_cpus);
arena_list_.reserve(num_cpus);
mux_ = std::make_unique<std::mutex[]>(num_cpus);
auto num_memory_nodes = num_cpus;
int64_t max_avail = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_;
int64_t arena_sz = max_avail / num_memory_nodes;
// If arena_sz is too small, lower the number of Arenas.
if (arena_sz < std::numeric_limits<int32_t>::max()) {
arena_sz = round_up_4K(std::numeric_limits<int32_t>::max());
num_memory_nodes = max_avail / arena_sz;
if (num_memory_nodes == 0) {
num_memory_nodes = 1;
arena_sz = max_avail;
}
}
MS_LOG(INFO) << "Creating " << num_memory_nodes << " number of arena. Each one of size " << arena_sz;

#ifdef NUMA_ENABLED
if (numa_available() != -1) {
auto num_numa_nodes = hw_->GetNumaNodeCount();
numa_id_t node_id = 0;
for (auto i = 0; i < num_memory_nodes; ++i) {
auto success = CreateMultipleArenas(arena_sz, node_id++ % num_numa_nodes, 1);
total_avail += success * arena_sz;
}
} else {
auto success = CreateMultipleArenas(arena_sz, 0, num_memory_nodes);
total_avail += success * arena_sz;
}
#else
auto success = CreateMultipleArenas(arena_sz, 0, num_memory_nodes);
total_avail += success * arena_sz;
#endif
memory_cap_ = total_avail;
MS_LOG(WARNING) << "Memory pool created. Total available memory " << memory_cap_ << " spread in " << nodes_.size()
<< " arenas";
int32_t slot = 0;
// Set up a map for future easy access.
for (auto node_id : nodes_) {
numa_map_[node_id].push_back(slot);
++slot;
}
}

int32_t NumaMemoryPool::CreateMultipleArenas(int64_t segment_sz, numa_id_t node_id, int32_t repeat_count) {
int32_t success = 0;
for (auto i = 0; i < repeat_count; ++i) {
#ifdef NUMA_ENABLED
void *ptr = numa_alloc_onnode(segment_sz, node_id);
#else
void *ptr = malloc(segment_sz);
#endif
if (ptr != nullptr) {
memory_segments_.emplace_back(ptr, segment_sz);
arena_list_.push_back(std::make_unique<ArenaImpl>(ptr, segment_sz));
nodes_.push_back(node_id);
++success;
} else {
// Skip the rest.
break;
}
}
MS_LOG(DEBUG) << "Allocate " << success << " arenas from node " << node_id;
return success;
}

NumaMemoryPool::~NumaMemoryPool() {
if (!memory_segments_.empty()) {
for (auto &s : memory_segments_) {
#ifdef NUMA_ENABLED
numa_free(s.first, s.second);
#else
free(s.first);
#endif
}
}
}

Status NumaMemoryPool::Allocate(size_t n, void **p) {
RETURN_UNEXPECTED_IF_NULL(p);
auto mt = GetRandomDevice();
Status rc;
void *ptr = nullptr;
auto num_segments = memory_segments_.size();
CHECK_FAIL_RETURN_UNEXPECTED(num_segments > 0, "No numa nodes available");
if (NumaAware()) {
auto num_numa_nodes = hw_->GetNumaNodeCount();
// We will start from the numa node this worker id is running on and do a round robin search.
numa_id_t start = hw_->GetMyNode();
numa_id_t node_id = start;
do {
auto it = numa_map_.find(node_id);
if (it != numa_map_.end()) {
auto &slots = it->second;
auto num_slots = slots.size();
std::uniform_int_distribution<int32_t> distribution(0, num_slots - 1);
auto start_slot = distribution(mt);
int32_t inx = start_slot;
do {
int32_t k = slots.at(inx);
std::unique_lock lock_x(mux_[k]);
auto &impl = arena_list_.at(k);
rc = impl->Allocate(n, &ptr);
if (rc.IsOk()) {
*p = ptr;
break;
} else if (rc.IsOutofMemory()) {
inx = (inx + 1) % num_slots;
} else {
return rc;
}
} while (inx != start_slot);
}
// We have done searching for this numa node. If not found, move to the next node.
if (ptr == nullptr) {
node_id = (node_id + 1) % num_numa_nodes;
} else {
break;
}
} while (node_id != start);
} else {
// If not numa aware, just randomly pick a slot.
std::uniform_int_distribution<int32_t> distribution(0, num_segments - 1);
auto start_slot = distribution(mt);
int32_t slot = start_slot;
do {
std::unique_lock lock_x(mux_[slot]);
auto &impl = arena_list_.at(slot);
rc = impl->Allocate(n, &ptr);
if (rc.IsOk()) {
*p = ptr;
break;
} else if (rc.IsOutofMemory()) {
// Make the next arena and continue.
slot = (slot + 1) % num_segments;
} else {
return rc;
}
} while (slot != start_slot);
}
// Handle the case we have done one round robin search.
if (ptr == nullptr) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
return rc;
}

void NumaMemoryPool::Deallocate(void *p) {
// Find out which numa slot it comes from.
auto slot = Locate(p);
MS_ASSERT(slot != -1);
std::unique_lock lock_x(mux_[slot]);
auto &impl = arena_list_.at(slot);
impl->Deallocate(p);
}

int NumaMemoryPool::PercentFree() const {
int percent_free = 0;
int num_arena = 0;
for (auto const &p : arena_list_) {
percent_free += p->PercentFree();
num_arena++;
}
if (num_arena) {
return percent_free / num_arena;
} else {
return 100;
}
}

int32_t NumaMemoryPool::Locate(void *p) const {
int32_t slot = 0;
char *mem = reinterpret_cast<char *>(p);
for (slot = 0; slot < memory_segments_.size(); ++slot) {
auto elem = memory_segments_.at(slot);
char *q = reinterpret_cast<char *>(elem.first);
if (mem >= q && mem < q + elem.second) {
return slot;
}
}
return -1;
}

std::vector<numa_id_t> NumaMemoryPool::GetAvailableNodes() const {
std::vector<numa_id_t> v;
std::transform(numa_map_.begin(), numa_map_.end(), std::back_inserter(v),
[](const std::pair<numa_id_t, std::vector<int32_t>> &v) { return v.first; });
return v;
}

} // namespace dataset
} // namespace mindspore

+ 195
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_numa.h View File

@@ -0,0 +1,195 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_NUMA_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_NUMA_H_

#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/cache_hw.h"
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/util/memory_pool.h"

namespace mindspore {
namespace dataset {
/// \brief An allocator but for a particular numa node.
template <typename T>
class NumaAllocator {
public:
explicit NumaAllocator(numa_id_t node_id, CachePoolPolicy policy)
: policy_(policy), numa_enabled_(false), node_id_(node_id) {
#ifdef NUMA_ENABLED
numa_enabled_ = numa_available() != -1;
#endif
}
~NumaAllocator() = default;

template <typename U>
explicit NumaAllocator(NumaAllocator<U> const &rhs)
: policy_(rhs.policy_), numa_enabled_(rhs.numa_enabled_), node_id_(rhs.node_id_) {}

template <typename U>
bool operator==(Allocator<U> const &rhs) const {
return node_id_ == rhs.node_id_;
}

template <typename U>
bool operator!=(Allocator<U> const &rhs) const {
return node_id_ != rhs.node_id_;
}

template <typename U>
friend class NumaAllocator;

using value_type = T;
using pointer = T *;
using const_pointer = const T *;
using reference = T &;
using const_reference = const T &;
using size_type = uint64_t;
using difference_type = std::ptrdiff_t;

template <typename U>
struct rebind {
using other = Allocator<U>;
};

using propagate_on_container_copy_assignment = std::true_type;
using propagate_on_container_move_assignment = std::true_type;
using propagate_on_container_swap = std::true_type;

/// Allocate memory on this node only. Return nullptr if no memory on this numa node.
/// \note. This version will not throw if we can't allocate memory from this node.
/// User must check if the pointer returned is null or not.
pointer allocate(std::size_t n) noexcept {
auto sz = n * sizeof(T);
void *p = nullptr;
#ifdef NUMA_ENABLED
if (numa_enabled_) {
switch (policy_) {
case kPreferred:
numa_set_preferred(node_id_);
p = numa_alloc(sz);
break;
case kLocal:
p = numa_alloc_local(sz);
break;
case kInterleave:
p = numa_alloc_interleaved(sz);
break;
case kOnNode:
p = numa_alloc_onnode(sz, node_id_);
break;
case kNone:
default:
p = numa_alloc(sz);
break;
}
} else {
p = malloc(sz);
}
#else
p = malloc(sz);
#endif
return reinterpret_cast<pointer>(p);
}

/// Free a memory allocated on this node.
void deallocate(pointer p, std::size_t n) noexcept {
#ifdef NUMA_ENABLED
if (numa_enabled_) {
numa_free(p, n * sizeof(T));
} else {
free(p);
}
#else
free(p);
#endif
}

/// \brief Allow one to change to another numa node
void SetNodeId(numa_id_t node_id) { node_id_ = node_id; }

/// \brif Getter for node_id;
numa_id_t GetNodeId() const { return node_id_; }

/// \brief Getter for policy
CachePoolPolicy GetPolicy() const { return policy_; }

private:
CachePoolPolicy policy_;
bool numa_enabled_;
numa_id_t node_id_;
};

/// \brief A NumaMemoryPool is like a CircularPool but all the arenas have already been allocated
/// and each one comes from a numa socket. Memory is allocated using OnNode policy. That is,
/// it is solely comes from one particular numa node, and is not interleaved.
class NumaMemoryPool : public MemoryPool {
public:
explicit NumaMemoryPool(std::shared_ptr<CacheServerHW> hw, float memory_cap_ratio);
~NumaMemoryPool() override;

// As a derived class, we override the following functions
Status Allocate(size_t size, void **pVoid) override;
void Deallocate(void *pVoid) override;
Status Reallocate(void **pVoid, size_t old_sz, size_t new_sz) override { RETURN_STATUS_UNEXPECTED("Not supported"); }
uint64_t get_max_size() const override { return std::numeric_limits<uint64_t>::max(); }
int PercentFree() const override;

/// \brief Return if the memory pool is numa aware
bool NumaAware() const { return CacheServerHW::numa_enabled(); }

/// \brief. This returns all the numa nodes that we are able to allocate memory from.
std::vector<numa_id_t> GetAvailableNodes() const;

/// \brief. Given a pointer (allocated from this pool), return the numa node where it is located.
/// \note. -1 is returned if not found.
numa_id_t FindNode(void *p) const {
auto slot = Locate(p);
if (slot != -1) {
return nodes_.at(slot);
} else {
return -1;
}
}

/// \brief Return maximum available memory
int64_t GetAvailableMemory() const { return memory_cap_; }

private:
std::shared_ptr<CacheServerHW> hw_;
float memory_cap_ratio_;
int64_t memory_cap_;
std::vector<std::pair<void *, int64_t>> memory_segments_;
std::vector<std::unique_ptr<ArenaImpl>> arena_list_;
std::unique_ptr<std::mutex[]> mux_;
std::vector<numa_id_t> nodes_;
std::map<numa_id_t, std::vector<int32_t>> numa_map_;

/// \brief. Returns the slot that a given memory comes from.
/// \return slot from numa_segments. -1 if not found.
int32_t Locate(void *p) const;

/// If numa library is not linked, or numa_availble() return -1, we will fall back to this method.
int32_t CreateMultipleArenas(int64_t segment_sz, numa_id_t node_id, int32_t repeat_count);
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_NUMA_H_

mindspore/ccsrc/minddata/dataset/util/cache_pool.cc → mindspore/ccsrc/minddata/dataset/engine/cache/cache_pool.cc View File

@@ -15,18 +15,14 @@
*/
#include <algorithm>
#include "utils/ms_utils.h"
#include "minddata/dataset/util/cache_pool.h"
#include "minddata/dataset/engine/cache/cache_pool.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/util/services.h"

namespace mindspore {
namespace dataset {
CachePool::CachePool(const value_allocator &alloc, bool ourOwnArena, const std::string &root)
: alloc_(alloc),
root_(root),
subfolder_(Services::GetUniqueID()),
sm_(nullptr),
tree_(nullptr),
custom_arena_(ourOwnArena) {}
CachePool::CachePool(std::shared_ptr<NumaMemoryPool> mp, const std::string &root)
: mp_(std::move(mp)), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {}

Status CachePool::DoServiceStart() {
tree_ = std::make_shared<data_index>();
@@ -36,10 +32,11 @@ Status CachePool::DoServiceStart() {
RETURN_IF_NOT_OK(spill.CreateDirectories());
sm_ = std::make_shared<StorageManager>(spill);
RETURN_IF_NOT_OK(sm_->ServiceStart());
MS_LOG(INFO) << "CachePool will use disk folder: " << common::SafeCStr(spill.toString());
MS_LOG(INFO) << "CachePool will use disk folder: " << spill.toString();
}
return Status::OK();
}

Status CachePool::DoServiceStop() {
Status rc;
Status rc2;
@@ -50,14 +47,14 @@ Status CachePool::DoServiceStop() {
}
}
sm_.reset();
// If it is our own arena, skip freeing individual pieces.
if (!custom_arena_) {
for (auto &bl : *tree_) {
if (bl.ptr != nullptr) {
alloc_.deallocate(bl.ptr, bl.sz);
}

value_allocator alloc(mp_);
for (auto &bl : *tree_) {
if (bl.ptr != nullptr) {
alloc.deallocate(bl.ptr, bl.sz);
}
}

tree_.reset();
if (!root_.toString().empty()) {
Path spill = GetSpillPath();
@@ -75,8 +72,10 @@ Status CachePool::DoServiceStop() {
}
return rc2;
}

CachePool::~CachePool() noexcept { (void)ServiceStop(); }
Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlice> &buf, bool writeToDiskDirectly) {

Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlice> &buf) {
DataLocator bl;
Status rc;
size_t sz = 0;
@@ -85,26 +84,35 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic
sz += v.GetSize();
}
bl.sz = sz;
try {
if (!writeToDiskDirectly) {
bl.ptr = alloc_.allocate(sz);
// We will do a piecewise copy.
WritableSlice dest(bl.ptr, bl.sz);
size_t pos = 0;
for (auto &v : buf) {
WritableSlice out(dest, pos);
rc = WritableSlice::Copy(&out, v);
if (rc.IsError()) {
break;
}
pos += v.GetSize();
}
rc = mp_->Allocate(sz, reinterpret_cast<void **>(&bl.ptr));
if (rc.IsOk()) {
// Write down which numa node where we allocate from. It only make sense if the policy is kOnNode.
if (CacheServerHW::numa_enabled()) {
auto &cs = CacheServer::GetInstance();
auto node_id = cs.GetHWControl()->GetMyNode();
bl.node_id = mp_->FindNode(bl.ptr);
CHECK_FAIL_RETURN_UNEXPECTED(bl.node_id != -1, "Allocator is not from numa memory pool");
bl.node_hit = (bl.node_id == node_id);
}
// We will do a piecewise copy.
WritableSlice dest(bl.ptr, bl.sz);
size_t pos = 0;
for (auto &v : buf) {
WritableSlice out(dest, pos);
rc = WritableSlice::Copy(&out, v);
if (rc.IsError()) {
alloc_.deallocate(bl.ptr, sz);
bl.ptr = nullptr;
return rc;
break;
}
} else if (sm_ != nullptr) {
pos += v.GetSize();
}
if (rc.IsError()) {
mp_->Deallocate(bl.ptr);
bl.ptr = nullptr;
return rc;
}
} else if (rc.IsOutofMemory()) {
// If no memory, write to disk.
if (sm_ != nullptr) {
MS_LOG(DEBUG) << "Spill to disk directly ... " << bl.sz << " bytes.";
RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf));
} else {
@@ -112,12 +120,8 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic
// instead.
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
} catch (std::bad_alloc &e) {
if (sm_ != nullptr) {
RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf));
} else {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
} else {
return rc;
}
// Insert into the B+ tree. We may still get out of memory error. So need to catch it.
try {
@@ -127,10 +131,13 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic
}
// Duplicate key is treated as error and we will also free the memory.
if (rc.IsError() && bl.ptr != nullptr) {
alloc_.deallocate(bl.ptr, sz);
mp_->Deallocate(bl.ptr);
bl.ptr = nullptr;
return rc;
}
return rc;
}

Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *bytesRead) const {
RETURN_UNEXPECTED_IF_NULL(dest);
auto r = tree_->Search(key);
@@ -156,13 +163,14 @@ Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *byt
}
return Status::OK();
}
const CachePool::value_allocator &CachePool::get_allocator() const { return alloc_; }
Path CachePool::GetSpillPath() const {
auto spill = Path(root_) / subfolder_;
return spill;
}

CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
CacheStat cs{-1, -1, 0, 0, 0};
CacheStat cs{-1, -1, 0, 0, 0, 0};
int64_t total_sz = 0;
if (tree_->begin() != tree_->end()) {
cs.min_key = tree_->begin().key();
@@ -174,6 +182,9 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
} else {
++cs.num_disk_cached;
}
if (it.value().node_hit) {
++cs.num_numa_hit;
}
auto cur_key = it.key();
if (GetMissingKeys) {
for (auto i = cs.max_key + 1; i < cur_key; ++i) {
@@ -192,49 +203,26 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
}
return cs;
}
Status CachePool::Spill(CachePool::DataLocator *dl) {
if (sm_ == nullptr) {
RETURN_STATUS_UNEXPECTED("No disk storage to spill");
}
RETURN_UNEXPECTED_IF_NULL(dl);
RETURN_UNEXPECTED_IF_NULL(dl->ptr);
if (dl->storage_key == 0) {
ReadableSlice data(dl->ptr, dl->sz);
RETURN_IF_NOT_OK(sm_->Write(&dl->storage_key, {data}));
}
alloc_.deallocate(dl->ptr, dl->sz);
dl->ptr = nullptr;
return Status::OK();
}
Status CachePool::Locate(CachePool::DataLocator *dl) {
RETURN_UNEXPECTED_IF_NULL(dl);
if (dl->ptr == nullptr) {
if (sm_ == nullptr) {
RETURN_STATUS_UNEXPECTED("No disk storage to locate the data");
}
try {
dl->ptr = alloc_.allocate(dl->sz);
WritableSlice dest(dl->ptr, dl->sz);
Status rc = Read(dl->storage_key, &dest);
if (rc.IsError()) {
alloc_.deallocate(dl->ptr, dl->sz);
dl->ptr = nullptr;
return rc;
}
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}
return Status::OK();
}
size_t CachePool::GetSize(CachePool::key_type key) const {

Status CachePool::GetDataLocator(key_type key, const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb,
flatbuffers::Offset<DataLocatorMsg> *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
auto r = tree_->Search(key);
if (r.second) {
auto &it = r.first;
return it->sz;
DataLocatorMsgBuilder bld(*fbb);
bld.add_key(key);
bld.add_size(it->sz);
bld.add_node_id(it->node_id);
bld.add_addr(reinterpret_cast<int64_t>(it->ptr));
auto offset = bld.Finish();
*out = offset;
} else {
return 0;
// Key not in the cache.
auto offset = CreateDataLocatorMsg(*fbb, key, 0, 0, 0);
*out = offset;
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

mindspore/ccsrc/minddata/dataset/util/cache_pool.h → mindspore/ccsrc/minddata/dataset/engine/cache/cache_pool.h View File

@@ -19,11 +19,14 @@
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_numa.h"
#include "minddata/dataset/engine/cache/storage_manager.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/service.h"
#include "minddata/dataset/util/slice.h"
#include "minddata/dataset/util/storage_manager.h"
#include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/util/btree.h"

@@ -45,13 +48,15 @@ class CachePool : public Service {
// An internal class to locate the whereabouts of a backed up buffer which can be either in
class DataLocator {
public:
DataLocator() : ptr(nullptr), sz(0), storage_key(0) {}
DataLocator() : ptr(nullptr), sz(0), node_id(0), node_hit(false), storage_key(0) {}
~DataLocator() = default;
DataLocator(const DataLocator &other) = default;
DataLocator &operator=(const DataLocator &other) = default;
DataLocator(DataLocator &&other) noexcept {
ptr = other.ptr;
sz = other.sz;
node_id = other.node_id;
node_hit = other.node_hit;
storage_key = other.storage_key;
other.ptr = nullptr;
other.sz = 0;
@@ -61,6 +66,8 @@ class CachePool : public Service {
if (&other != this) {
ptr = other.ptr;
sz = other.sz;
node_id = other.node_id;
node_hit = other.node_hit;
storage_key = other.storage_key;
other.ptr = nullptr;
other.sz = 0;
@@ -70,6 +77,8 @@ class CachePool : public Service {
}
pointer ptr;
size_t sz;
numa_id_t node_id; // where the numa node the memory is allocated to
bool node_hit; // we can allocate to the preferred node
StorageManager::key_type storage_key;
};

@@ -85,19 +94,20 @@ class CachePool : public Service {
int64_t num_mem_cached;
int64_t num_disk_cached;
int64_t average_cache_sz;
int64_t num_numa_hit;
std::vector<key_type> gap;
};

/// \brief Constructor
/// \param alloc Allocator to allocate memory from
/// \param root Optional disk folder to spill
explicit CachePool(const value_allocator &alloc, bool customArena, const std::string &root = "");
explicit CachePool(std::shared_ptr<NumaMemoryPool> mp, const std::string &root = "");

CachePool(const CachePool &) = delete;
CachePool(CachePool &&) = delete;
CachePool &operator=(const CachePool &) = delete;
CachePool &operator=(CachePool &&) = delete;
~CachePool() noexcept;
~CachePool() noexcept override;

Status DoServiceStart() override;
Status DoServiceStop() override;
@@ -110,7 +120,8 @@ class CachePool : public Service {
/// \param[in] buf A sequence of ReadableSlice objects.
/// \param[in] writeToDiskDirectly If true, no spill to disk if spill is enabled, or return no memory
/// \return Error code
Status Insert(key_type key, const std::vector<ReadableSlice> &buf, bool writeToDiskDirectly);
Status Insert(CachePool::key_type key, const std::vector<ReadableSlice> &buf);

/// \brief Restore a cached buffer (from memory or disk)
/// \param[in] key A previous key returned from Insert
/// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice
@@ -118,18 +129,14 @@ class CachePool : public Service {
/// \return Error code
Status Read(key_type key, WritableSlice *dest, size_t *bytesRead = nullptr) const;

Status Spill(DataLocator *dl);

Status Locate(DataLocator *dl);

size_t GetSize(key_type key) const;
/// \brief Serialize a DataLocator
Status GetDataLocator(key_type, const std::shared_ptr<flatbuffers::FlatBufferBuilder> &,
flatbuffers::Offset<DataLocatorMsg> *) const;

/// \brief Get statistics.
/// \return CacheStat object
CacheStat GetStat(bool GetMissingKeys = false) const;

const value_allocator &get_allocator() const;

std::string MyName() const { return subfolder_; }

/// \brief Toggle locking
@@ -137,12 +144,11 @@ class CachePool : public Service {
void SetLocking(bool on_off) { tree_->SetLocking(on_off); }

private:
value_allocator alloc_;
std::shared_ptr<NumaMemoryPool> mp_;
Path root_;
const std::string subfolder_;
std::shared_ptr<StorageManager> sm_;
std::shared_ptr<data_index> tree_;
bool custom_arena_;
};
} // namespace dataset
} // namespace mindspore

+ 62
- 8
mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc View File

@@ -14,6 +14,11 @@
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_request.h"
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID)
#include <sched.h>
#include <sys/types.h>
#include <unistd.h>
#endif
#include <cstdlib>
#include <thread>
#include "minddata/dataset/core/constants.h"
@@ -106,6 +111,7 @@ Status CacheRowRequest::PostReply() {
}
return Status::OK();
}

Status CacheRowRequest::Prepare() {
if (BitTest(rq_.flag(), kDataIsInSharedMemory)) {
// First one is cookie, followed by address and then size.
@@ -118,10 +124,21 @@ Status CacheRowRequest::Prepare() {
return Status::OK();
}

BatchFetchRequest::BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id,
bool local_bypass)
: BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(local_bypass), row_id_(row_id) {
rq_.set_connection_id(connection_id);
CacheRowRequest::CacheRowRequest(const CacheClient *cc)
: BaseRequest(RequestType::kCacheRow),
support_local_bypass_(cc->local_bypass_),
addr_(-1),
sz_(0),
row_id_from_server_(-1) {
rq_.set_connection_id(cc->server_connection_id_);
rq_.set_client_id(cc->client_id_);
rq_.add_buf_data(cc->cookie_);
}

BatchFetchRequest::BatchFetchRequest(const CacheClient *cc, const std::vector<row_id_type> &row_id)
: BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(cc->local_bypass_), row_id_(row_id) {
rq_.set_connection_id(cc->server_connection_id_);
rq_.set_client_id(cc->client_id_);
rq_.set_flag(support_local_bypass_ ? kLocalClientSupport : 0);
// Convert the row id into a flatbuffer
flatbuffers::FlatBufferBuilder fbb;
@@ -186,9 +203,9 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out, const void *baseAddr, in
return Status::OK();
}

CreateCacheRequest::CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
CreateCacheRequest::CreateCacheRequest(CacheClient *cc, const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
CreateCacheRequest::CreateCacheFlag flag)
: BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag) {
: BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag), cc_(cc) {
// Type has been set already in the base constructor. So we need to fill in the connection info.
// On successful return, we will get the connection id
rq_.mutable_connection_info()->operator=(cinfo);
@@ -209,6 +226,41 @@ Status CreateCacheRequest::Prepare() {
}
}

Status CreateCacheRequest::PostReply() {
auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data());
cc_->server_connection_id_ = p->connection_id();
cc_->cookie_ = p->cookie()->str();
cc_->client_id_ = p->client_id();
// Next is a set of cpu id that we should re-adjust ourselves for better affinity.
auto sz = p->cpu_id()->size();
cc_->cpu_list_.reserve(sz);
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID)
std::string c_list;
cpu_set_t cpu_set;
CPU_ZERO(&cpu_set);
#endif
for (auto i = 0; i < sz; ++i) {
auto cpu_id = p->cpu_id()->Get(i);
cc_->cpu_list_.push_back(cpu_id);
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID)
c_list += std::to_string(cpu_id) + " ";
CPU_SET(cpu_id, &cpu_set);
#endif
}

#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID)
if (sz > 0) {
auto err = sched_setaffinity(getpid(), sizeof(cpu_set), &cpu_set);
if (err == -1) {
RETURN_STATUS_UNEXPECTED("Unable to set affinity. Errno = " + std::to_string(errno));
}
MS_LOG(WARNING) << "Changing cpu affinity to the following list of cpu id: " + c_list;
}
#endif

return Status::OK();
}

Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) {
try {
flatbuffers::FlatBufferBuilder fbb;
@@ -245,6 +297,7 @@ Status GetStatRequest::PostReply() {
stat_.num_disk_cached = msg->num_disk_cached();
stat_.num_mem_cached = msg->num_mem_cached();
stat_.avg_cache_sz = msg->avg_cache_sz();
stat_.num_numa_hit = msg->num_numa_hit();
stat_.max_row_id = msg->max_row_id();
stat_.min_row_id = msg->min_row_id();
stat_.cache_service_state = msg->state();
@@ -255,14 +308,15 @@ Status ListSessionsRequest::PostReply() {
auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data());
auto session_vector = msg->sessions();
for (auto i = 0; i < session_vector->size(); ++i) {
SessionCacheInfo current_info;
CacheServiceStat stats;
SessionCacheInfo current_info{};
CacheServiceStat stats{};
auto current_session_info = session_vector->Get(i);
current_info.session_id = current_session_info->session_id();
current_info.connection_id = current_session_info->connection_id();
stats.num_mem_cached = current_session_info->stats()->num_mem_cached();
stats.num_disk_cached = current_session_info->stats()->num_disk_cached();
stats.avg_cache_sz = current_session_info->stats()->avg_cache_sz();
stats.num_numa_hit = current_session_info->stats()->num_numa_hit();
stats.min_row_id = current_session_info->stats()->min_row_id();
stats.max_row_id = current_session_info->stats()->max_row_id();
stats.cache_service_state = current_session_info->stats()->state();


+ 32
- 18
mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h View File

@@ -41,6 +41,7 @@ struct CacheServiceStat {
int64_t num_mem_cached;
int64_t num_disk_cached;
int64_t avg_cache_sz;
int64_t num_numa_hit;
row_id_type min_row_id;
row_id_type max_row_id;
int8_t cache_service_state;
@@ -75,6 +76,8 @@ class BaseRequest {
kHeartBeat = 14,
kToggleWriteMode = 15,
kListSessions = 16,
kConnectReset = 17,
kInternalFetchRow = 18,
// Add new request before it.
kRequestUnknown = 32767
};
@@ -84,10 +87,15 @@ class BaseRequest {
friend class CacheClientGreeter;
friend class CacheClientRequestTag;
friend class CacheClient;
friend class CacheService;

/// \brief Base class of a cache server request
/// \param type Type of the request
explicit BaseRequest(RequestType type) : type_(type) { rq_.set_type(static_cast<int16_t>(type_)); }
explicit BaseRequest(RequestType type) : type_(type) {
rq_.set_type(static_cast<int16_t>(type_));
rq_.set_client_id(-1);
rq_.set_flag(0);
}
virtual ~BaseRequest() = default;

/// \brief A print method for debugging
@@ -138,15 +146,7 @@ class CacheRowRequest : public BaseRequest {
public:
friend class CacheServer;
friend class CacheClient;
explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie, bool local_bypass)
: BaseRequest(RequestType::kCacheRow),
support_local_bypass_(local_bypass),
addr_(-1),
sz_(0),
row_id_from_server_(-1) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(cookie);
}
explicit CacheRowRequest(const CacheClient *cc);
~CacheRowRequest() override = default;

/// \brief Serialize a TensorRow for streaming to the cache server
@@ -193,7 +193,7 @@ class BatchFetchRequest : public BaseRequest {
public:
friend class CacheServer;
friend class CacheService;
BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, bool local_bypass);
BatchFetchRequest(const CacheClient *cc, const std::vector<row_id_type> &row_id);
~BatchFetchRequest() override = default;
Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr);

@@ -212,21 +212,18 @@ class CreateCacheRequest : public BaseRequest {
/// \param connection_id
/// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited
/// \param flag Attributes of the cache.
explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
explicit CreateCacheRequest(CacheClient *cc, const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
CreateCacheFlag flag = CreateCacheFlag::kNone);
~CreateCacheRequest() override = default;
void ParseResult(connection_id_type *id, std::string *out) {
auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data());
*id = p->connection_id();
*out = p->cookie()->str();
}

/// Overload the base class Prepare
/// Overload the base class Prepare/PostReply
Status Prepare() override;
Status PostReply() override;

private:
uint64_t cache_mem_sz_;
CreateCacheFlag flag_;
CacheClient *cc_;
};

/// \brief Request to get all the keys not present at the server.
@@ -396,6 +393,23 @@ class ToggleWriteModeRequest : public BaseRequest {
}
~ToggleWriteModeRequest() override = default;
};

class ConnectResetRequest : public BaseRequest {
public:
friend class CacheServer;
explicit ConnectResetRequest(connection_id_type connection_id, int32_t client_id)
: BaseRequest(RequestType::kConnectReset) {
rq_.set_connection_id(connection_id);
rq_.set_client_id(client_id);
}
~ConnectResetRequest() override = default;

/// Override the base class function
Status Prepare() override {
CHECK_FAIL_RETURN_UNEXPECTED(rq_.client_id() != -1, "Invalid client id");
return Status::OK();
}
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_

+ 266
- 124
mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc View File

@@ -17,6 +17,7 @@
#include <algorithm>
#include <functional>
#include <limits>
#include <vector>
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
#include "minddata/dataset/engine/cache/cache_service.h"
@@ -43,36 +44,57 @@ Status CacheServer::DoServiceStart() {
MS_LOG(INFO) << "CacheServer will use disk folder: " << top_;
}
RETURN_IF_NOT_OK(vg_.ServiceStart());
// There will be num_workers_ threads working on the grpc queue and
// the same number of threads working on the CacheServerRequest queue.
RETURN_IF_NOT_OK(hw_info_->GetNumaNodeInfo());
auto num_numa_nodes = GetNumaNodeCount();
// If we link with numa library. Set default memory policy.
// If we don't pin thread to cpu, then use up all memory controllers to maximize
// memory bandwidth.
RETURN_IF_NOT_OK(
CacheServerHW::SetDefaultMemoryPolicy(numa_affinity_ ? CachePoolPolicy::kLocal : CachePoolPolicy::kInterleave));
auto my_node = hw_info_->GetMyNode();
MS_LOG(DEBUG) << "Cache server is running on numa node " << my_node;
// Bump up num_workers_ to at least the number of numa nodes
num_workers_ = std::max(num_numa_nodes, num_workers_);
// But also it shouldn't be too many more than the hardware concurrency
auto num_cpus = hw_info_->GetCpuCount();
num_workers_ = std::min(2 * num_cpus, num_workers_);
// Round up num_workers to a multiple of numa nodes.
auto remainder = num_workers_ % num_numa_nodes;
if (remainder > 0) num_workers_ += (num_numa_nodes - remainder);
MS_LOG(INFO) << "Re-adjusting the number of workers to " << num_workers_;
// There will be some threads working on the grpc queue and
// some number of threads working on the CacheServerRequest queue.
// Like a connector object we will set up the same number of queues but
// we do not need to preserve any order. We will set the capacity of
// each queue to be 128 since we are just pushing memory pointers which
// each queue to be 64 since we are just pushing memory pointers which
// is only 8 byte each.
const int32_t que_capacity = 128;
const int32_t kQueCapacity = 64;
// This is the request queue from the client
cache_q_ = std::make_shared<QueueList<CacheServerRequest *>>();
cache_q_->Init(num_workers_, que_capacity);
cache_q_->Init(num_workers_, kQueCapacity);
// We will match the number of grpc workers with the number of server workers.
// But technically they don't have to be the same.
num_grpc_workers_ = num_workers_;
MS_LOG(DEBUG) << "Number of gprc workers is set to " << num_grpc_workers_;
// For the grpc completion queue to work, we need to allocate some
// tags which in our case are instances of CacheServerQuest.
// They got recycled and we will allocate them in advance and push
// them into some free list. We need more (two or three times) the
// size of the cache_q. While each worker is working on a CacheSerRequest,
// we need some extra running injecting in the the qrpc completion queue.
const int32_t multiplier = 3;
const int32_t free_list_capacity = multiplier * (que_capacity + 1);
const int32_t kMultiplier = 2;
int ratio = num_workers_ / num_grpc_workers_;
if (num_workers_ % num_grpc_workers_) ++ratio;
const int32_t free_list_capacity = kMultiplier * (kQueCapacity + 1) * ratio;
free_list_ = std::make_shared<QueueList<CacheServerRequest *>>();
free_list_->Init(num_workers_, free_list_capacity);
// We need to have a reference to the services memory pool in case
// the Services goes out of scope earlier than us since it is a singleton
mp_ = Services::GetInstance().GetServiceMemPool();
Allocator<CacheServerRequest> alloc(mp_);
tag_.reserve(num_workers_);
// Now we populate all free list.
for (auto m = 0; m < num_workers_; ++m) {
// Ideally we allocate all the free list in one malloc. But it turns out it exceeds the
// Arena size. So we will we will allocate one segment at a time.
auto my_tag = std::make_unique<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>(alloc);
free_list_->Init(num_grpc_workers_, free_list_capacity);
tag_.reserve(num_grpc_workers_);
// Now we populate all free list. Round robin the free list among the numa nodes.
for (auto m = 0; m < num_grpc_workers_; ++m) {
NumaAllocator<CacheServerRequest> alloc(m % num_numa_nodes, CachePoolPolicy::kPreferred);
// Ideally we allocate all the free list in one malloc. But we will allocate one segment
// at a time so that we can change the numa policy easily per grpc worker.
auto my_tag = std::make_unique<MemGuard<CacheServerRequest, NumaAllocator<CacheServerRequest>>>(alloc);
// Allocate the tag and assign it the current queue
RETURN_IF_NOT_OK(my_tag->allocate(free_list_capacity, m));
for (int i = 0; i < free_list_capacity; ++i) {
@@ -82,11 +104,6 @@ Status CacheServer::DoServiceStart() {
}
RETURN_IF_NOT_OK(cache_q_->Register(&vg_));
RETURN_IF_NOT_OK(free_list_->Register(&vg_));
// Spawn a few threads to serve the real request.
auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1);
for (auto i = 0; i < num_workers_; ++i) {
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache service worker", std::bind(f, i)));
}
// Start the comm layer
try {
comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_, shared_memory_sz_in_gb_);
@@ -94,10 +111,29 @@ Status CacheServer::DoServiceStart() {
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
// Spawn a few threads to serve the real request.
auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1);
for (auto i = 0; i < num_workers_; ++i) {
Task *pTask;
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache service worker", std::bind(f, i), &pTask));
// Save a copy of the pointer to the underlying Task object. We may dynamically change their affinity if needed.
numa_tasks_.emplace(i, pTask);
// Spread out all the threads to all the numa nodes if needed
if (IsNumaAffinityOn()) {
auto numa_id = i % num_numa_nodes;
RETURN_IF_NOT_OK(SetAffinity(*pTask, numa_id));
}
}
// Finally loop forever to handle the request.
auto r = std::bind(&CacheServer::RpcRequest, this, std::placeholders::_1);
for (auto i = 0; i < num_workers_; ++i) {
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("rpc worker", std::bind(r, i)));
for (auto i = 0; i < num_grpc_workers_; ++i) {
Task *pTask;
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("rpc worker", std::bind(r, i), &pTask));
// All these grpc workers will be allocated to the same node which is where we allocate all those free tag
// memory.
if (IsNumaAffinityOn()) {
RETURN_IF_NOT_OK(SetAffinity(*pTask, i % num_numa_nodes));
}
}
return Status::OK();
}
@@ -108,8 +144,6 @@ Status CacheServer::DoServiceStop() {
// First stop all the threads.
RETURN_IF_NOT_OK(vg_.ServiceStop());
// Clean up all the caches if any.
// Dump a message how much memory we have consumed in total.
MS_LOG(INFO) << "Memory usage for the current server: " << GetMemoryUsage() << " bytes.";
UniqueLock lck(&rwLock_);
auto it = all_caches_.begin();
while (it != all_caches_.end()) {
@@ -134,13 +168,14 @@ CacheService *CacheServer::GetService(connection_id_type id) const {
Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing connection info");
std::string cookie;
int32_t client_id;
auto session_id = rq->connection_info().session_id();
auto crc = rq->connection_info().crc();

// Before allowing the creation, make sure the session had already been created by the user
// Our intention is to add this cache to the active sessions list so leave the list locked during
// this entire function.
UniqueLock lock(&sessions_lock_);
UniqueLock sess_lck(&sessions_lock_);
auto session_it = active_sessions_.find(session_id);
if (session_it == active_sessions_.end()) {
RETURN_STATUS_UNEXPECTED("A cache creation has been requested but the session was not found!");
@@ -163,6 +198,7 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
}
flatbuffers::FlatBufferBuilder fbb;
flatbuffers::Offset<flatbuffers::String> off_cookie;
flatbuffers::Offset<flatbuffers::Vector<cpu_id_t>> off_cpu_list;
// Before creating the cache, first check if this is a request for a shared usage of an existing cache
// If two CreateService come in with identical connection_id, we need to serialize the create.
// The first create will be successful and be given a special cookie.
@@ -171,32 +207,74 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
if (global_shutdown_) {
return Status::OK();
}
// We would like to protect ourselves from over allocating too much. We will go over existing cache
// and calculate how much we have consumed so far.
auto end = all_caches_.end();
auto it = all_caches_.find(connection_id);
auto it = all_caches_.begin();
bool duplicate = false;
auto avail_mem = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_;
int64_t max_avail = avail_mem;
while (it != end) {
if (it->first == connection_id) {
duplicate = true;
break;
} else {
auto &cs = it->second;
CacheService::ServiceStat stat;
RETURN_IF_NOT_OK(cs->GetStat(&stat));
int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz;
max_avail -= mem_consumed;
if (max_avail <= 0) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
}
}
++it;
}
if (it == end) {
// If we have some cache using some memory already, make a reasonable decision if we should return
// out of memory.
if (max_avail < avail_mem) {
int64_t req_mem = cache_mem_sz * 1048576L; // It is in MB unit.
if (req_mem > max_avail) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
} else if (req_mem == 0) {
// This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than
// 85% of our limit, fail this request.
if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= 0.15) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
}
}
}
std::unique_ptr<CacheService> cs;
try {
cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id);
RETURN_IF_NOT_OK(cs->ServiceStart());
cookie = cs->cookie();
client_id = cs->num_clients_.fetch_add(1);
all_caches_.emplace(connection_id, std::move(cs));
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
}
// Add the cache into the active session tracking.
// We have already validated that the session exists and that this is a new cache created.
session_it->second.insert(connection_id);

} else {
duplicate = true;
client_id = it->second->num_clients_.fetch_add(1);
MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service";
}

// Shuffle the worker threads. But we need to release the locks or we will deadlock when calling
// the following function
lck.Unlock();
sess_lck.Unlock();
auto numa_id = client_id % GetNumaNodeCount();
std::vector<cpu_id_t> cpu_list = hw_info_->GetCpuList(numa_id);
// Send back the data
off_cookie = fbb.CreateString(cookie);
off_cpu_list = fbb.CreateVector(cpu_list);
CreateCacheReplyMsgBuilder bld(fbb);
bld.add_connection_id(connection_id);
bld.add_cookie(off_cookie);
bld.add_client_id(client_id);
// The last thing we send back is a set of cpu id that we suggest the client should bind itself to
bld.add_cpu_id(off_cpu_list);
auto off = bld.Finish();
fbb.Finish(off);
reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
@@ -220,26 +298,8 @@ Status CacheServer::DestroyCache(CacheRequest *rq) {
MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service";
}
}

// Now that this cache is removed, we need to also remove it's connection id from active session tracking
auto session_id = GetSessionID(id);
UniqueLock sess_lck(&sessions_lock_);

auto it = active_sessions_.find(session_id);
if (it == active_sessions_.end()) {
// The session was not found in the active sessions
RETURN_STATUS_UNEXPECTED("A destroy cache request has been completed but it had a stale session id!");
}

auto connection_it = it->second.find(id);
if (connection_it == it->second.end()) {
RETURN_STATUS_UNEXPECTED("A destroy cache request could not find the connection in the activate sessions!");
}

// remove that connection id from the set
it->second.erase(connection_it);
MS_LOG(INFO) << "Destroyed cache " << id << " and removed from active session " << session_id;

// We aren't touching the session list even though we may be dropping the last remaining cache of a session.
// Leave that to be done by the drop session command.
return Status::OK();
}

@@ -266,6 +326,7 @@ Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) {
buffers.push_back(rq->buf_data(i).data());
}
row_id_type id = -1;
// We will allocate the memory the same numa node this thread is bound to.
RETURN_IF_NOT_OK(cs->CacheRow(buffers, &id));
reply->set_result(std::to_string(id));
} else {
@@ -301,6 +362,7 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
if (!cs->HasBuildPhase() || cookie == cs->cookie()) {
row_id_type id = -1;
ReadableSlice src(p, sz);
// We will allocate the memory the same numa node this thread is bound to.
rc = cs->FastCacheRow(src, &id);
reply->set_result(std::to_string(id));
} else {
@@ -330,9 +392,19 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
for (auto i = 0; i < sz; ++i) {
row_id.push_back(p->row_id()->Get(i));
}
int64_t mem_sz = 0;
std::vector<key_size_pair> v;
RETURN_IF_NOT_OK(cs->PreBatchFetch(row_id, &v, &mem_sz));
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
RETURN_IF_NOT_OK(cs->PreBatchFetch(connection_id, row_id, fbb));
auto locator = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
int64_t mem_sz = sizeof(int64_t) * (sz + 1);
for (auto i = 0; i < sz; ++i) {
auto row_sz = locator->rows()->Get(i)->size();
// row_sz is the size of the cached data. Later we will spawn multiple threads
// each of which will copy the data into either shared memory or protobuf concurrently but
// to different region.
// To avoid false sharing, we will bump up row_sz to be a multiple of 4k, i.e. 4096 bytes
row_sz = round_up_4K(row_sz);
mem_sz += row_sz;
}
auto client_flag = rq->flag();
bool local_client = BitTest(client_flag, kLocalClientSupport);
// For large amount data to be sent back, we will use shared memory provided it is a local
@@ -346,7 +418,11 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
void *q = nullptr;
RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q));
WritableSlice dest(q, mem_sz);
RETURN_IF_NOT_OK(cs->BatchFetch(row_id, v, &dest));
Status rc = cs->BatchFetch(fbb, &dest);
if (rc.IsError()) {
shared_pool->Deallocate(q);
return rc;
}
// We can't return the absolute address which makes no sense to the client.
// Instead we return the difference.
auto difference = reinterpret_cast<int64_t>(q) - reinterpret_cast<int64_t>(base);
@@ -363,7 +439,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
return Status(StatusCode::kOutOfMemory);
}
WritableSlice dest(mem.data(), mem_sz);
RETURN_IF_NOT_OK(cs->BatchFetch(row_id, v, &dest));
RETURN_IF_NOT_OK(cs->BatchFetch(fbb, &dest));
reply->set_result(std::move(mem));
}
}
@@ -386,6 +462,7 @@ Status CacheServer::GetStat(CacheRequest *rq, CacheReply *reply) {
bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached);
bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached);
bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz);
bld.add_num_numa_hit(svc_stat.stat_.num_numa_hit);
bld.add_max_row_id(svc_stat.stat_.max_key);
bld.add_min_row_id(svc_stat.stat_.min_key);
bld.add_state(svc_stat.state_);
@@ -506,30 +583,27 @@ Status CacheServer::ToggleWriteMode(CacheRequest *rq) {
}

Status CacheServer::ListSessions(CacheReply *reply) {
SharedLock lck(&sessions_lock_);
SharedLock sess_lck(&sessions_lock_);
SharedLock lck(&rwLock_);
flatbuffers::FlatBufferBuilder fbb;
std::vector<flatbuffers::Offset<ListSessionMsg>> session_msgs_vector;
for (auto it = active_sessions_.begin(); it != active_sessions_.end(); it++) {
session_id_type current_session_id = it->first;
// Loop over each cache inside this session
if (!it->second.empty()) {
for (auto current_conn_id : it->second) {
CacheService *cs = GetService(current_conn_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(current_conn_id) + " not found during list sessions.";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
CacheService::ServiceStat svc_stat;
RETURN_IF_NOT_OK(cs->GetStat(&svc_stat));
auto current_stats = CreateServiceStatMsg(fbb, svc_stat.stat_.num_mem_cached, svc_stat.stat_.num_disk_cached,
svc_stat.stat_.average_cache_sz, svc_stat.stat_.min_key,
svc_stat.stat_.max_key, svc_stat.state_);
auto current_session_info = CreateListSessionMsg(fbb, current_session_id, current_conn_id, current_stats);
session_msgs_vector.push_back(current_session_info);
}
for (auto const &current_session_id : active_sessions_) {
bool found = false;
for (auto const &it : all_caches_) {
auto current_conn_id = it.first;
if (GetSessionID(current_conn_id) == current_session_id) {
found = true;
auto &cs = it.second;
CacheService::ServiceStat svc_stat;
RETURN_IF_NOT_OK(cs->GetStat(&svc_stat));
auto current_stats = CreateServiceStatMsg(fbb, svc_stat.stat_.num_mem_cached, svc_stat.stat_.num_disk_cached,
svc_stat.stat_.average_cache_sz, svc_stat.stat_.num_numa_hit,
svc_stat.stat_.min_key, svc_stat.stat_.max_key, svc_stat.state_);
auto current_session_info = CreateListSessionMsg(fbb, current_session_id, current_conn_id, current_stats);
session_msgs_vector.push_back(current_session_info);
}
} else {
}
if (!found) {
// If there is no cache created yet, assign a connection id of 0 along with empty stats
auto current_stats = CreateServiceStatMsg(fbb, 0, 0, 0, 0, 0, 0);
auto current_session_info = CreateListSessionMsg(fbb, current_session_id, 0, current_stats);
@@ -542,18 +616,35 @@ Status CacheServer::ListSessions(CacheReply *reply) {
auto offset = s_builder.Finish();
fbb.Finish(offset);
reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
return Status::OK();
}

Status CacheServer::ConnectReset(CacheRequest *rq) {
auto connection_id = rq->connection_id();
// Hold the shared lock to prevent the cache from being dropped.
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto client_id = rq->client_id();
MS_LOG(WARNING) << "Client id " << client_id << " with connection id " << connection_id << " disconnects";
cs->num_clients_--;
}
return Status::OK();
}

/// \brief This is the main loop the cache server thread(s) are running.
/// Each thread will pop a request and send the result back to the client using grpc
/// \return
Status CacheServer::ServerRequest(int32_t worker_id) {
Status CacheServer::ServerRequest(worker_id_t worker_id) {
TaskManager::FindMe()->Post();
MS_LOG(DEBUG) << "Worker id " << worker_id << " is running on node " << hw_info_->GetMyNode();
auto &my_que = cache_q_->operator[](worker_id);
// Loop forever until we are interrupted or shutdown.
while (!global_shutdown_) {
bool internal_request = false;
CacheServerRequest *cache_req = nullptr;
RETURN_IF_NOT_OK(my_que->PopFront(&cache_req));
auto &rq = cache_req->rq_;
@@ -571,8 +662,17 @@ Status CacheServer::ServerRequest(int32_t worker_id) {
}
break;
}
case BaseRequest::RequestType::kBatchFetchRows: {
cache_req->rc_ = BatchFetchRows(&rq, &reply);
case BaseRequest::RequestType::kInternalFetchRow: {
internal_request = true;
auto connection_id = rq.connection_id();
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq.buf_data(0).data()));
}
break;
}
case BaseRequest::RequestType::kCreateCache: {
@@ -636,6 +736,10 @@ Status CacheServer::ServerRequest(int32_t worker_id) {
cache_req->rc_ = ListSessions(&reply);
break;
}
case BaseRequest::RequestType::kConnectReset: {
cache_req->rc_ = ConnectReset(&rq);
break;
}
default:
std::string errMsg("Unknown request type : ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
@@ -647,7 +751,13 @@ Status CacheServer::ServerRequest(int32_t worker_id) {
// We will re-tag the request back to the grpc queue. Once it comes back from the client,
// the CacheServerRequest, i.e. the pointer cache_req, will be free
if (!global_shutdown_) {
cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req);
if (!internal_request) {
cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req);
} else {
// This is an internal request and is not tied to rpc. But need to post because there
// is a thread waiting on the completion of this request.
cache_req->wp_.Set();
}
}
}
return Status::OK();
@@ -667,12 +777,20 @@ CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int
int32_t shared_meory_sz_in_gb, float memory_cap_ratio)
: top_(spill_path),
num_workers_(num_workers),
num_grpc_workers_(num_workers_),
port_(port),
shared_memory_sz_in_gb_(shared_meory_sz_in_gb),
global_shutdown_(false),
memory_cap_ratio_(memory_cap_ratio),
cur_mem_usage_(0) {
memory_cap_ = CacheServer::GetTotalSystemMemory() * memory_cap_ratio_;
numa_affinity_(true) {
hw_info_ = std::make_shared<CacheServerHW>();
// If we are not linked with numa library (i.e. NUMA_ENABLED is false), turn off cpu
// affinity which can make performance worse.
if (!CacheServerHW::numa_enabled()) {
numa_affinity_ = false;
MS_LOG(WARNING) << "Warning: This build is not compiled with numa support. Install libnuma-devel and use a build "
"that is compiled with numa support for more optimal performance";
}
}

Status CacheServer::Run(int msg_qid) {
@@ -719,51 +837,52 @@ Status CacheServer::ReturnRequestTag(CacheServerRequest *p) {
Status CacheServer::DestroySession(CacheRequest *rq) {
CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id");
auto drop_session_id = rq->connection_info().session_id();

UniqueLock lck(&sessions_lock_);

// First validate that this session exists
auto it = active_sessions_.find(drop_session_id);
if (it == active_sessions_.end()) {
RETURN_STATUS_UNEXPECTED("A destroy session command has been requested but the session was not found!");
}

// Grab the locks in the correct order to avoid deadlock.
UniqueLock sess_lck(&sessions_lock_);
UniqueLock lck(&rwLock_);
// Iterate over the set of connection id's for this session that we're dropping and erase each one.
{
UniqueLock rwlck(&rwLock_);
for (auto drop_connection_id : it->second) {
auto cache_drop_it = all_caches_.find(drop_connection_id);
if (cache_drop_it == all_caches_.end()) {
RETURN_STATUS_UNEXPECTED("active session tracking had stale or incorrect cache entry.");
}
all_caches_.erase(cache_drop_it);
MS_LOG(INFO) << "Session destroy: Destroy cache with id " << drop_connection_id;
// **Do not bother to remove the cache connection id from the active session because we will soon remove the
// entire session.
bool found = false;
for (auto it = all_caches_.begin(); it != all_caches_.end();) {
auto connection_id = it->first;
auto session_id = GetSessionID(connection_id);
// We can just call DestroyCache() but we are holding a lock already. Doing so will cause deadlock.
// So we will just manually do it.
if (session_id == drop_session_id) {
found = true;
it = all_caches_.erase(it);
MS_LOG(INFO) << "Destroy cache with id " << connection_id;
} else {
++it;
}
}

// Finally remove the session itself
active_sessions_.erase(it);
MS_LOG(INFO) << "Session destroyed with id " << drop_session_id;

return Status::OK();
auto n = active_sessions_.erase(drop_session_id);
if (n > 0) {
MS_LOG(INFO) << "Session destroyed with id " << drop_session_id;
return Status::OK();
} else {
if (found) {
std::string errMsg =
"A destroy cache request has been completed but it had a stale session id " + std::to_string(drop_session_id);
RETURN_STATUS_UNEXPECTED(errMsg);
} else {
std::string errMsg = "Session id " + std::to_string(drop_session_id) + " not found.";
return Status(StatusCode::kFileNotExist, errMsg);
}
}
}

session_id_type CacheServer::GenerateSessionID() {
UniqueLock lock(&sessions_lock_);
UniqueLock sess_lck(&sessions_lock_);
auto mt = GetRandomDevice();
std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<session_id_type>::max());
session_id_type session_id;
bool duplicate = false;
do {
session_id = distribution(mt);
auto it = active_sessions_.find(session_id);
duplicate = (it != active_sessions_.end());
auto r = active_sessions_.insert(session_id);
duplicate = !r.second;
} while (duplicate);

// Add this session to our tracking of active sessions with initialized empty set of connections.
active_sessions_[session_id] = std::set<connection_id_type>();
return session_id;
}

@@ -789,7 +908,7 @@ Status CacheServer::FreeSharedMemory(CacheRequest *rq) {
return Status::OK();
}

Status CacheServer::RpcRequest(int32_t worker_id) {
Status CacheServer::RpcRequest(worker_id_t worker_id) {
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(comm_layer_->HandleRequest(worker_id));
return Status::OK();
@@ -820,12 +939,22 @@ Status CacheServer::GlobalShutdown() {
return Status::OK();
}

int64_t CacheServer::GetTotalSystemMemory() {
auto pages = sysconf(_SC_PHYS_PAGES);
auto page_size = sysconf(_SC_PAGE_SIZE);
auto total = static_cast<int64_t>(pages) * static_cast<int64_t>(page_size);
MS_LOG(INFO) << "Total physical RAM in bytes: " << total;
return total;
worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) {
auto num_numa_nodes = GetNumaNodeCount();
MS_ASSERT(numa_id < num_numa_nodes);
auto num_workers_per_node = GetNumWorkers() / num_numa_nodes;
std::mt19937 gen = GetRandomDevice();
std::uniform_int_distribution<worker_id_t> dist(0, num_workers_per_node - 1);
auto n = dist(gen);
worker_id_t worker_id = n * num_numa_nodes + numa_id;
MS_ASSERT(worker_id < GetNumWorkers());
return worker_id;
}

worker_id_t CacheServer::GetRandomWorker() {
std::mt19937 gen = GetRandomDevice();
std::uniform_int_distribution<worker_id_t> dist(0, num_workers_ - 1);
return dist(gen);
}

Status CacheServer::Builder::IpcResourceCleanup() {
@@ -842,6 +971,8 @@ Status CacheServer::Builder::IpcResourceCleanup() {
rc = mem.Attach();
if (rc.IsError()) {
return Status::OK();
} else {
RETURN_IF_NOT_OK(mem.Detach());
}
int32_t num_attached;
RETURN_IF_NOT_OK(mem.GetNumAttached(&num_attached));
@@ -892,5 +1023,16 @@ Status CacheServer::Builder::SanityCheck() {
RETURN_IF_NOT_OK(IpcResourceCleanup());
return Status::OK();
}

CacheServer::Builder::Builder()
: top_("/tmp"),
num_workers_(std::thread::hardware_concurrency() / 2),
port_(50052),
shared_memory_sz_in_gb_(4),
memory_cap_ratio_(0.8) {
if (num_workers_ == 0) {
num_workers_ = 1;
}
}
} // namespace dataset
} // namespace mindspore

+ 53
- 32
mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h View File

@@ -17,23 +17,31 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_

#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <algorithm>
#include <atomic>
#include <chrono>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <map>
#include <set>
#include <thread>
#include "minddata/dataset/engine/cache/cache_hw.h"
#include "minddata/dataset/engine/cache/cache_numa.h"
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_grpc_server.h"
#include "minddata/dataset/engine/cache/cache_pool.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/util/cache_pool.h"
#include "minddata/dataset/util/lock.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/semaphore.h"
#include "minddata/dataset/util/service.h"
#include "minddata/dataset/util/services.h"
#include "minddata/dataset/util/system_pool.h"
@@ -47,9 +55,10 @@ class CacheServer : public Service {
public:
friend class Services;
using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>;

class Builder {
public:
Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4), memory_cap_ratio_(0.8) {}
Builder();

~Builder() = default;

@@ -161,26 +170,40 @@ class CacheServer : public Service {
/// \return Status object
static Status ReturnRequestTag(CacheServerRequest *p);

/// \brief This returns the size (in bytes) of the physical RAM on the machine.
/// \return the size (in bytes) of the physical RAM on the machine.
static int64_t GetTotalSystemMemory();
/// Return an instance of the numa control
std::shared_ptr<CacheServerHW> GetHWControl() { return hw_info_; }

/// \brief Internally this is how much we will try to use without exceeding the limit
/// \return Internal cap maximum
int64_t GetAvailableSystemMemory() { return memory_cap_; }
/// \brief Set CPU affinity
Status SetAffinity(const Task &tk, numa_id_t numa_node) { return hw_info_->SetAffinity(tk, numa_node); }

/// \brief Find out the current memory usage
int64_t GetMemoryUsage() { return cur_mem_usage_; }
/// \brief return number of workers
auto GetNumWorkers() const { return num_workers_; }

/// \brief This updates our current memory usage.
enum MemUsageOp : int8_t { kAllocate = 1, kFree = 2 };
void UpdateMemoryUsage(int64_t sz, MemUsageOp op) {
if (op == MemUsageOp::kAllocate) {
cur_mem_usage_ += sz;
} else {
cur_mem_usage_ -= sz;
}
}
/// \brief return number of grpc workers
auto GetNumGrpcWorkers() const { return num_grpc_workers_; }

/// \brief return number of numa nodes
auto GetNumaNodeCount() const { return hw_info_->GetNumaNodeCount(); }

/// \brief Assign a worker by a numa id
/// \return worker id
worker_id_t GetWorkerByNumaId(numa_id_t node_id);

/// \brief Randomly pick a worker
/// \return worker id
worker_id_t GetRandomWorker();

/// \brief Check if we bind threads to numa cores
bool IsNumaAffinityOn() const { return numa_affinity_; }

/// \brief Internal function to do row batch fetch
/// \param rq Request
/// \param reply Reply
/// \return Status object
Status BatchFetchRows(CacheRequest *rq, CacheReply *reply);

/// \brief Return the memory cap ratio
float GetMemoryCapRatio() const { return memory_cap_ratio_; }

private:
static std::once_flag init_instance_flag_;
@@ -189,20 +212,21 @@ class CacheServer : public Service {
mutable RWLock sessions_lock_;
std::string top_;
cache_index all_caches_;
std::map<session_id_type, std::set<connection_id_type>> active_sessions_;
std::set<session_id_type> active_sessions_;
std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_;
std::shared_ptr<QueueList<CacheServerRequest *>> free_list_;
std::vector<std::unique_ptr<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>> tag_;
std::vector<std::unique_ptr<MemGuard<CacheServerRequest, NumaAllocator<CacheServerRequest>>>> tag_;
std::shared_ptr<CacheServerGreeterImpl> comm_layer_;
std::shared_ptr<MemoryPool> mp_;
TaskGroup vg_;
int32_t num_workers_;
int32_t num_grpc_workers_;
int32_t port_;
int32_t shared_memory_sz_in_gb_;
std::atomic<bool> global_shutdown_;
float memory_cap_ratio_;
int64_t memory_cap_;
std::atomic<int64_t> cur_mem_usage_;
std::shared_ptr<CacheServerHW> hw_info_;
std::map<worker_id_t, Task *> numa_tasks_;
bool numa_affinity_;

/// \brief Constructor
/// \param spill_path Top directory for spilling buffers to.
@@ -226,11 +250,11 @@ class CacheServer : public Service {
Status DestroyCache(CacheRequest *rq);

/// \brief Entry point for all internal server threads.
Status ServerRequest(int32_t worker_id);
Status ServerRequest(worker_id_t worker_id);

/// \brief Entry point for all grpc threads.
/// \return
Status RpcRequest(int32_t worker_id);
Status RpcRequest(worker_id_t worker_id);

Status DestroySession(CacheRequest *rq);

@@ -266,12 +290,6 @@ class CacheServer : public Service {
Status FastCacheRow(CacheRequest *rq, CacheReply *reply);
Status CacheRow(CacheRequest *rq, CacheReply *reply);

/// \brief Internal function to do row batch fetch
/// \param rq Request
/// \param reply Reply
/// \return Status object
Status BatchFetchRows(CacheRequest *rq, CacheReply *reply);

/// \brief Internal function to get statistics
/// \param rq
/// \param reply
@@ -309,6 +327,9 @@ class CacheServer : public Service {
/// \param reply
/// \return Status object
Status ListSessions(CacheReply *reply);

/// \brief Connect request by a pipeline
Status ConnectReset(CacheRequest *rq);
};
} // namespace dataset
} // namespace mindspore


+ 126
- 103
mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc View File

@@ -13,51 +13,45 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <random>
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/engine/cache/cache_numa.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/slice.h"

namespace mindspore {
namespace dataset {
CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool generate_id)
: root_(root),
cache_mem_sz_(mem_sz),
cache_mem_sz_(mem_sz * 1048576L), // mem_sz is in MB unit
cp_(nullptr),
next_id_(0),
generate_id_(generate_id),
st_(generate_id ? State::kBuildPhase : State::kNone),
cur_mem_usage_(0),
cur_disk_usage_(0) {}
num_clients_(0),
st_(generate_id ? CacheServiceState::kBuildPhase : CacheServiceState::kNone) {}

CacheService::~CacheService() { (void)ServiceStop(); }

bool CacheService::UseArena() {
// If fixed size, use Arena instead of the pool from global context.
return (cache_mem_sz_ > 0);
}

Status CacheService::DoServiceStart() {
std::shared_ptr<MemoryPool> mp_;
CacheServer &cs = CacheServer::GetInstance();
if (UseArena()) {
auto avail_mem = cs.GetAvailableSystemMemory() / 1048576L;
float memory_cap_ratio = cs.GetMemoryCapRatio();
if (cache_mem_sz_ > 0) {
auto avail_mem = CacheServerHW::GetTotalSystemMemory();
if (cache_mem_sz_ > avail_mem) {
// Output a warning that we use more than recommended. If we fail to allocate, we will fail anyway.
MS_LOG(WARNING) << "Requesting cache size " << cache_mem_sz_ << " MB while available system memory " << avail_mem
<< " MB";
MS_LOG(WARNING) << "Requesting cache size " << cache_mem_sz_ << " while available system memory " << avail_mem;
}
// Create a fixed size arena based on the parameter.
std::shared_ptr<Arena> arena;
RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_));
mp_ = std::move(arena);
// update the global usage only.
cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kAllocate);
} else {
// Unlimited size. Simply use a system pool. Another choice is CircularPool.
mp_ = std::make_shared<SystemPool>();
memory_cap_ratio = static_cast<float>(cache_mem_sz_) / avail_mem;
}
numa_pool_ = std::make_shared<NumaMemoryPool>(cs.GetHWControl(), memory_cap_ratio);
// It is possible we aren't able to allocate the pool for many reasons.
std::vector<numa_id_t> avail_nodes = numa_pool_->GetAvailableNodes();
if (avail_nodes.empty()) {
RETURN_STATUS_UNEXPECTED("Unable to bring up numa memory pool");
}
// Put together a CachePool for backing up the Tensor
cp_ = std::make_shared<CachePool>(CachePool::value_allocator(mp_), UseArena(), root_);
// Put together a CachePool for backing up the Tensor.
cp_ = std::make_shared<CachePool>(numa_pool_, root_);
RETURN_IF_NOT_OK(cp_->ServiceStart());
// Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name.
cookie_ = cp_->MyName();
@@ -68,26 +62,18 @@ Status CacheService::DoServiceStop() {
if (cp_ != nullptr) {
RETURN_IF_NOT_OK(cp_->ServiceStop());
}
CacheServer &cs = CacheServer::GetInstance();
if (UseArena()) {
cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kFree);
} else {
MS_LOG(INFO) << "Memory/disk usage for the current service: " << GetMemoryUsage() << " bytes and " << GetDiskUsage()
<< " bytes.";
cs.UpdateMemoryUsage(GetMemoryUsage(), CacheServer::MemUsageOp::kFree);
}
return Status::OK();
}

Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) {
SharedLock rw(&rw_lock_);
RETURN_UNEXPECTED_IF_NULL(row_id_generated);
if (st_ == State::kFetchPhase) {
if (st_ == CacheServiceState::kFetchPhase) {
// For this kind of cache service, once we are done with the build phase into fetch phase, we can't
// allow other to cache more rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
if (st_ == State::kNoLocking) {
if (st_ == CacheServiceState::kNoLocking) {
// We ignore write this request once we turn off locking on the B+ tree. So we will just
// return out of memory from now on.
return Status(StatusCode::kOutOfMemory);
@@ -128,26 +114,13 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i));
total_sz += msg->data_sz()->Get(i);
}
// Now we cache the buffer. If we are using Arena which has a fixed cap, then just do it.
// Otherwise, we check how much (globally) how much we use and may simply spill to disk
// directly.
CacheServer &cs = CacheServer::GetInstance();
bool write_to_disk_directly = UseArena() ? false : (total_sz + cs.GetMemoryUsage()) > cs.GetAvailableSystemMemory();
Status rc = cp_->Insert(*row_id_generated, all_data, write_to_disk_directly);
// Now we cache the buffer.
Status rc = cp_->Insert(*row_id_generated, all_data);
if (rc == Status(StatusCode::kDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key.";
} else {
RETURN_IF_NOT_OK(rc);
}
// All good, then update the memory usage local and global (if not using arena)
if (write_to_disk_directly) {
cur_disk_usage_ += total_sz;
} else {
cur_mem_usage_ += total_sz;
if (!UseArena()) {
cs.UpdateMemoryUsage(total_sz, CacheServer::MemUsageOp::kAllocate);
}
}
return Status::OK();
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
@@ -157,12 +130,12 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) {
SharedLock rw(&rw_lock_);
RETURN_UNEXPECTED_IF_NULL(row_id_generated);
if (st_ == State::kFetchPhase) {
if (st_ == CacheServiceState::kFetchPhase) {
// For this kind of cache service, once we are done with the build phase into fetch phase, we can't
// allow other to cache more rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
if (st_ == State::kNoLocking) {
if (st_ == CacheServiceState::kNoLocking) {
// We ignore write this request once we turn off locking on the B+ tree. So we will just
// return out of memory from now on.
return Status(StatusCode::kOutOfMemory);
@@ -183,27 +156,13 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_
}
*row_id_generated = msg->row_id();
}
// Now we cache the buffer. If we are using Arena which has a fixed cap, then just do it.
// Otherwise, we check how much (globally) how much we use and may simply spill to disk
// directly.
auto total_sz = src.GetSize();
CacheServer &cs = CacheServer::GetInstance();
bool write_to_disk_directly = UseArena() ? false : (total_sz + cs.GetMemoryUsage()) > cs.GetAvailableSystemMemory();
Status rc = cp_->Insert(*row_id_generated, {src}, write_to_disk_directly);
// Now we cache the buffer.
Status rc = cp_->Insert(*row_id_generated, {src});
if (rc == Status(StatusCode::kDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key.";
} else {
RETURN_IF_NOT_OK(rc);
}
// All good, then update the memory usage local and global (if not using arena)
if (write_to_disk_directly) {
cur_disk_usage_ += total_sz;
} else {
cur_mem_usage_ += total_sz;
if (!UseArena()) {
cs.UpdateMemoryUsage(total_sz, CacheServer::MemUsageOp::kAllocate);
}
}
return Status::OK();
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
@@ -247,52 +206,116 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) {
return Status::OK();
}

Status CacheService::PreBatchFetch(const std::vector<row_id_type> &v, std::vector<key_size_pair> *out,
int64_t *mem_sz) {
Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v,
const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb) {
SharedLock rw(&rw_lock_);
RETURN_UNEXPECTED_IF_NULL(out);
RETURN_UNEXPECTED_IF_NULL(mem_sz);
const auto num_elements = v.size();
*mem_sz = (num_elements + 1) * sizeof(int64_t);
(*out).reserve(num_elements);
std::vector<flatbuffers::Offset<DataLocatorMsg>> datalocator_v;
datalocator_v.reserve(v.size());
for (auto row_id : v) {
auto sz = cp_->GetSize(row_id);
if (sz > 0) {
(*out).emplace_back(row_id, sz);
(*mem_sz) += sz;
} else {
// key not found
(*out).emplace_back(-1, 0);
}
flatbuffers::Offset<DataLocatorMsg> offset;
RETURN_IF_NOT_OK(cp_->GetDataLocator(row_id, fbb, &offset));
datalocator_v.push_back(offset);
}
auto offset_v = fbb->CreateVector(datalocator_v);
BatchDataLocatorMsgBuilder bld(*fbb);
bld.add_connection_id(connection_id);
bld.add_rows(offset_v);
auto offset_final = bld.Finish();
fbb->Finish(offset_final);
return Status::OK();
}

Status CacheService::BatchFetch(const std::vector<row_id_type> &v, const std::vector<key_size_pair> &info,
WritableSlice *out) const {
Status CacheService::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
SharedLock rw(&rw_lock_);
if (st_ == State::kBuildPhase) {
if (st_ == CacheServiceState::kBuildPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
const auto num_elements = v.size();
CacheServer &cs = CacheServer::GetInstance();
int32_t numQ = cs.GetNumGrpcWorkers();
auto rng = GetRandomDevice();
std::uniform_int_distribution<session_id_type> distribution(0, numQ - 1);
int32_t qID = distribution(rng);
std::vector<CacheServerRequest *> cache_rq_list;
auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
const auto num_elements = p->rows()->size();
auto connection_id = p->connection_id();
cache_rq_list.reserve(num_elements);
int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
offset_array[0] = data_offset;
for (auto i = 0; i < num_elements; ++i) {
auto sz = info.at(i).second;
offset_array[i + 1] = offset_array[i] + sz;
auto data_locator = p->rows()->Get(i);
auto node_id = data_locator->node_id();
size_t sz = data_locator->size();
void *source_addr = reinterpret_cast<void *>(data_locator->addr());
auto key = data_locator->key();
// Please read the comment in CacheServer::BatchFetchRows where we allocate
// the buffer big enough so each thread (which we are going to dispatch) will
// not run into false sharing problem. We are going to round up sz to 4k.
auto sz_4k = round_up_4K(sz);
offset_array[i + 1] = offset_array[i] + sz_4k;
if (sz > 0) {
WritableSlice row_data(*out, offset_array[i], sz);
auto key = info.at(i).first;
size_t bytesRead = 0;
RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead));
if (bytesRead != sz) {
MS_LOG(ERROR) << "Unexpected length. Read " << bytesRead << ". Expected " << sz << "."
<< " Internal key: " << key << "\n";
RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
}
// Get a request and send to the proper worker (at some numa node) to do the fetch.
worker_id_t worker_id = cs.IsNumaAffinityOn() ? cs.GetWorkerByNumaId(node_id) : cs.GetRandomWorker();
CacheServerRequest *cache_rq;
RETURN_IF_NOT_OK(cs.GetFreeRequestTag(qID++ % numQ, &cache_rq));
cache_rq_list.push_back(cache_rq);
// Set up all the necessarily field.
cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow;
cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
cache_rq->rq_.set_connection_id(connection_id);
cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
auto dest_addr = row_data.GetMutablePointer();
flatbuffers::FlatBufferBuilder fb2;
FetchRowMsgBuilder bld(fb2);
bld.add_key(key);
bld.add_size(sz);
bld.add_source_addr(reinterpret_cast<int64_t>(source_addr));
bld.add_dest_addr(reinterpret_cast<int64_t>(dest_addr));
auto offset = bld.Finish();
fb2.Finish(offset);
cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize());
RETURN_IF_NOT_OK(cs.PushRequest(worker_id, cache_rq));
}
}
// Now wait for all of them to come back. Let go of the shared lock. We shouldn't be holding
// any lock while we can wait for a long time.
rw.Unlock();
Status rc;
for (CacheServerRequest *rq : cache_rq_list) {
RETURN_IF_NOT_OK(rq->Wait());
if (rq->rc_.IsError() && !rq->rc_.IsInterrupted() && rc.IsOk()) {
rc = rq->rc_;
}
RETURN_IF_NOT_OK(cs.ReturnRequestTag(rq));
}
return rc;
}

Status CacheService::InternalFetchRow(const FetchRowMsg *p) {
RETURN_UNEXPECTED_IF_NULL(p);
SharedLock rw(&rw_lock_);
size_t bytesRead = 0;
int64_t key = p->key();
size_t sz = p->size();
void *source_addr = reinterpret_cast<void *>(p->source_addr());
void *dest_addr = reinterpret_cast<void *>(p->dest_addr());
WritableSlice dest(dest_addr, sz);
if (source_addr != nullptr) {
// We are not checking if the row is still present but simply use the information passed in.
// This saves another tree lookup and is faster.
ReadableSlice src(source_addr, sz);
RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src));
} else {
RETURN_IF_NOT_OK(cp_->Read(key, &dest, &bytesRead));
if (bytesRead != sz) {
std::string errMsg = "Unexpected length. Read " + std::to_string(bytesRead) + ". Expected " + std::to_string(sz) +
"." + " Internal key: " + std::to_string(key);
MS_LOG(ERROR) << errMsg;
RETURN_STATUS_UNEXPECTED(errMsg);
}
}
return Status::OK();
@@ -312,7 +335,7 @@ Status CacheService::CacheSchema(const void *buf, int64_t len) {

Status CacheService::FetchSchema(std::string *out) const {
SharedLock rw(&rw_lock_);
if (st_ == State::kBuildPhase) {
if (st_ == CacheServiceState::kBuildPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
@@ -333,7 +356,7 @@ Status CacheService::BuildPhaseDone() {
if (HasBuildPhase()) {
// Exclusive lock to switch phase
UniqueLock rw(&rw_lock_);
st_ = State::kFetchPhase;
st_ = CacheServiceState::kFetchPhase;
cp_->SetLocking(false);
return Status::OK();
} else {
@@ -348,12 +371,12 @@ Status CacheService::ToggleWriteMode(bool on_off) {
} else {
// If we stop accepting write request, we turn off locking for the
// underlying B+ tree. All future write request we will return kOutOfMemory.
if (st_ == State::kNone && !on_off) {
st_ = State::kNoLocking;
if (st_ == CacheServiceState::kNone && !on_off) {
st_ = CacheServiceState::kNoLocking;
cp_->SetLocking(on_off);
MS_LOG(WARNING) << "Locking mode is switched off.";
} else if (st_ == State::kNoLocking && on_off) {
st_ = State::kNone;
} else if (st_ == CacheServiceState::kNoLocking && on_off) {
st_ = CacheServiceState::kNone;
cp_->SetLocking(on_off);
}
}


+ 13
- 27
mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h View File

@@ -29,36 +29,28 @@
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_pool.h"
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/util/btree.h"
#include "minddata/dataset/util/cache_pool.h"
#include "minddata/dataset/util/service.h"
#include "minddata/dataset/util/services.h"
#include "minddata/dataset/util/system_pool.h"

namespace mindspore {
namespace dataset {
/// Some typedef used for BatchFetch
using key_size_pair = std::pair<CachePool::key_type, size_t>;
/// \brief A cache service for storing/fetching buffers to in memory cache and may spill to disk the cache service is
/// created to support spilling
class CacheService : public Service {
public:
friend class CacheServer;

enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking };

/// \brief Constructor
/// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited
/// \param root Spill path. Empty string means no spilling
/// \param generate_id If the cache service should generate row id for buffer that is cached.
/// For non-mappable dataset, this should be set to true.
CacheService(uint64_t mem_sz, const std::string &root, bool generate_id);
~CacheService();

/// \brief For fixed size memory, we will create an Arena.
/// \return false if unlimited memory.
bool UseArena();
~CacheService() override;

Status DoServiceStart() override;
Status DoServiceStop() override;
@@ -77,18 +69,18 @@ class CacheService : public Service {
Status FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated);

/// \brief This function is used in preparation for batch fetching.
/// It calculates how much memory we should allocate and which row id are present.
/// \param[in/out] Pointer to vector of <CachePool::key_type, size_t>
/// \param[in/out] mem_sz how much memory is required to batch fetch
/// It calculates how much memory we should allocate and which row id are present, etc.
/// All needed results are stored in the flat buffer.
/// \return Status object
Status PreBatchFetch(const std::vector<row_id_type> &v, std::vector<key_size_pair> *, int64_t *mem_sz);
Status PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v,
const std::shared_ptr<flatbuffers::FlatBufferBuilder> &);

/// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
/// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
/// \param[in] v A vector of row id.
/// \param[out] out A contiguous memory buffer that holds the requested rows.
/// \return Status object
Status BatchFetch(const std::vector<row_id_type> &v, const std::vector<key_size_pair> &, WritableSlice *out) const;
Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &, WritableSlice *out) const;

/// \brief Getter function
/// \return Spilling path
@@ -96,7 +88,7 @@ class CacheService : public Service {
/// \brief A structure returned from the cache server for statistics request.
class ServiceStat {
public:
using state_type = std::underlying_type<State>::type;
using state_type = std::underlying_type<CacheServiceState>::type;
ServiceStat() : state_(0) {}
~ServiceStat() = default;
CachePool::CacheStat stat_{};
@@ -134,10 +126,6 @@ class CacheService : public Service {
/// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call.
/// \return Status object
Status BuildPhaseDone();
/// \brief Find out the current memory usage
int64_t GetMemoryUsage() { return cur_mem_usage_; }
/// \brief Find out the current disk usage
int64_t GetDiskUsage() { return cur_disk_usage_; }
/// \brief For kToggleWriteMode request
Status ToggleWriteMode(bool on_off);

@@ -149,14 +137,10 @@ class CacheService : public Service {
std::atomic<row_id_type> next_id_;
bool generate_id_;
std::string cookie_;
State st_;
std::atomic<int32_t> num_clients_;
CacheServiceState st_;
std::string schema_;
// If we use an Arena, cur_disk_usage is always 0 as we don't know how CachePool manages it.
// Otherwise we track how much is in memory and how much is on disk (if root_ is not empty).
// We use them to control when we should stop caching in memory in the case when there is no
// Arena.
std::atomic<int64_t> cur_mem_usage_;
std::atomic<int64_t> cur_disk_usage_;
std::shared_ptr<NumaMemoryPool> numa_pool_;
// We also cache the result from calling FindKeysMiss because it is expensive. Besides user make
// this request after we hit memory full or disk full. So the result is unlikely to change.
std::mutex get_key_miss_mux_;
@@ -164,6 +148,8 @@ class CacheService : public Service {
/// \brief Private function to generate a row id
/// \return Row id assigned.
row_id_type GetNextRowId() { return next_id_.fetch_add(1); }

Status InternalFetchRow(const FetchRowMsg *p);
};
} // namespace dataset
} // namespace mindspore


+ 23
- 1
mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs View File

@@ -65,6 +65,7 @@ table ServiceStatMsg {
num_mem_cached:int64;
num_disk_cached:int64;
avg_cache_sz:int64;
num_numa_hit:int64;
min_row_id:int64;
max_row_id:int64;
state:int8;
@@ -89,8 +90,10 @@ table CreateCacheRequestMsg {

/// Return result of CreateCacheRequest
table CreateCacheReplyMsg {
connection_id:int64;
client_id:int32;
connection_id:uint64;
cookie:string;
cpu_id:[int32];
}

table ListSessionMsg {
@@ -102,3 +105,22 @@ table ListSessionMsg {
table ListSessionsMsg {
sessions:[ListSessionMsg];
}

table DataLocatorMsg {
key:int64;
node_id:int32;
addr:int64;
size:int64;
}

table BatchDataLocatorMsg {
connection_id:uint64;
rows:[DataLocatorMsg];
}

table FetchRowMsg {
key:int64;
source_addr:int64;
dest_addr:int64;
size:int64;
}

+ 32
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/perf/CMakeLists.txt View File

@@ -0,0 +1,32 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)

if (ENABLE_CACHE)
ms_protobuf_generate(CACHE_PERF_PROTO_SRCS CACHE_PERF_PROTO_HDRS cache_perf.proto)

add_executable(cache_perf cache_perf.cc cache_msg.cc cache_perf_run.cc ${CACHE_PERF_PROTO_SRCS})
target_link_libraries(cache_perf
_c_dataengine
_c_mindrecord
mindspore::protobuf
mindspore_gvar
${PYTHON_LIBRARIES}
pthread)

if (USE_GLOG)
target_link_libraries(cache_perf mindspore::glog)
endif ()

add_executable(cache_pipeline cache_pipeline.cc cache_msg.cc cache_pipeline_run.cc ${CACHE_PERF_PROTO_SRCS})
target_link_libraries(cache_pipeline
_c_dataengine
_c_mindrecord
mindspore::protobuf
mindspore_gvar
${PYTHON_LIBRARIES}
pthread)

if (USE_GLOG)
target_link_libraries(cache_pipeline mindspore::glog)
endif ()
endif ()

+ 48
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_msg.cc View File

@@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/cache/perf/cache_msg.h"
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/msg.h>

namespace mindspore {
namespace dataset {
Status CachePerfMsg::Send(int32_t qID) {
auto err = msgsnd(qID, reinterpret_cast<void *>(&small_msg_), sizeof(small_msg_.body.msg), IPC_NOWAIT);
if (err == -1) {
std::string errMsg = "Failed to call msgsnd. Errno = " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg);
}
return Status::OK();
}

Status CachePerfMsg::Receive(int32_t qID) {
// This is a blocking call. Either there is some message or we the queue is removed when
// the destructor is called.
auto err = msgrcv(qID, reinterpret_cast<void *>(&small_msg_), sizeof(small_msg_.body.msg), 0, MSG_NOERROR);
if (err == -1) {
if (errno == EIDRM) {
return Status(StatusCode::kInterrupted);
} else {
std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg);
}
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 78
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_msg.h View File

@@ -0,0 +1,78 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_MSG_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_MSG_H_

#include <cstdint>
#include <limits>
#include <string>
#include "proto/cache_perf.pb.h"
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
// All our messages are very small. So we will use the stack version without the need
// to allocate memory.
struct CacheSmallMsg {
int64_t mtype;
union {
char mtext[1];
struct {
int32_t type; // the first 4 bytes is the RequestType
int32_t proto_sz;
char proto_buffer[kSharedMessageSize];
} msg;
} body;
};
/// A message queue structure between the parent and the child process
class CachePerfMsg {
public:
enum MessageType : int16_t {
kInterrupt = 0,
kEpochResult = 1,
kEpochStart = 2,
kEpochEnd = 3,
kError = 4,
// Add new message before it.
kUnknownMessage = 32767
};
CachePerfMsg() : small_msg_{1} {
small_msg_.body.msg.type = kUnknownMessage;
small_msg_.body.msg.proto_sz = 0;
small_msg_.body.msg.proto_buffer[0] = 0;
}
~CachePerfMsg() = default;

char *GetMutableBuffer() { return small_msg_.body.msg.proto_buffer; }

Status Send(int32_t qID);

void SetType(MessageType requestType) { small_msg_.body.msg.type = requestType; }
void SetProtoBufSz(size_t sz) { small_msg_.body.msg.proto_sz = sz; }

MessageType GetType() const { return static_cast<MessageType>(small_msg_.body.msg.type); }
size_t GetProtoBufSz() const { return small_msg_.body.msg.proto_sz; }

Status Receive(int32_t qID);

private:
CacheSmallMsg small_msg_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_MSG_H_

+ 39
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf.cc View File

@@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifdef USE_GLOG
#include <glog/logging.h>
#endif
#include <iostream>
#include "minddata/dataset/engine/cache/perf/cache_perf_run.h"
namespace ds = mindspore::dataset;

int main(int argc, char **argv) {
#ifdef USE_GLOG
FLAGS_log_dir = "/tmp";
google::InitGoogleLogging(argv[0]);
#endif
ds::CachePerfRun cachePerfRun;
if (cachePerfRun.ProcessArgs(argc, argv) == 0) {
std::cout << cachePerfRun << std::endl;
ds::Status rc = cachePerfRun.Run();
if (rc.IsError()) {
std::cerr << rc.ToString() << std::endl;
}
return static_cast<int>(rc.get_code());
}
return 0;
}

+ 39
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf.proto View File

@@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

syntax = "proto3";
package mindspore.dataset;
option cc_enable_arenas = true;

message PipelineWorkerEpochSummary {
int32 pipeline = 1;
int32 worker = 2;
int64 min = 3;
int64 max = 4;
int64 avg = 5;
int64 med = 6;
int64 cnt = 7;
int64 elapse = 8;
}

message EpochDone {
int32 pipeline = 1;
}

message ErrorMsg {
int32 rc = 1;
string msg = 2;
}

+ 575
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.cc View File

@@ -0,0 +1,575 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/cache/perf/cache_perf_run.h"
#include <string.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <unistd.h>
#include <algorithm>
#include <chrono>
#include <iomanip>
#include <sstream>
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/services.h"
#include "minddata/dataset/util/sig_handler.h"

namespace mindspore {
namespace dataset {
const char CachePerfRun::kCachePipelineBinary[] = "cache_pipeline";
void CachePerfRun::PrintHelp() {
std::cout << "Options:\n"
" -h,--help: Show this usage message\n"
" -s,--num_rows: Set the sample size, i.e., the number of "
"rows\n"
" -r,--row_size: Set the average row size\n"
" -n,--pipeline: Set the number of parallel pieplines. Default = "
<< kDftNumOfPipelines
<< "\n"
" -e,--epoch: Set the number of epochs. Default = "
<< kDftNumberOfEpochs
<< "\n"
" --shuffle: Set shuffle=True. Default = "
<< std::boolalpha << kDftShuffle
<< "\n"
" -p,--prefetch_size: Set the prefetch size for cache. Default = "
<< kDftPrefetchSize << "\n"
<< " -a,--cache_size: Set cache size. Default = " << kDftCacheSize
<< " (Mb)\n"
" --spill: Set spill to disk to True. Default = "
<< std::boolalpha << kDftSpill << "\n"
<< " -w,--workers: Set the number of parallel workers. Default = " << cfg_.num_parallel_workers()
<< "\n"
" --connection: Set number of TCP/IP connections per pipeline. Default = "
<< kDftNumConnections << "\n"
<< " --port: TCP/IP port of the cache server. Default = " << kCfgDefaultCachePort << "\n"
<< " --hostname: Hostname of the cache server. Default = " << kCfgDefaultCacheHost << "\n";
}

int32_t CachePerfRun::ProcessArgs(int argc, char **argv) {
if (argc == 1) {
PrintHelp();
return -1;
}

const int32_t port_opt = 1000; // there is no short option for port
const int32_t hostname_opt = 1001; // there is no short option for hostname
const int32_t connect_opt = 1002; // there is no short option for connect

int shuffle = 0;
int spill = 0;

const char *const short_opts = ":n:e:p:a:s:r:w:";
const option long_opts[] = {{"pipeline", required_argument, nullptr, 'n'},
{"epoch", required_argument, nullptr, 'e'},
{"prefetch_size", required_argument, nullptr, 'p'},
{"shuffle", no_argument, &shuffle, 1},
{"cache_size", required_argument, nullptr, 'a'},
{"num_rows", required_argument, nullptr, 's'},
{"row_size", required_argument, nullptr, 'r'},
{"workers", required_argument, nullptr, 'w'},
{"port", required_argument, nullptr, port_opt},
{"hostname", required_argument, nullptr, hostname_opt},
{"spill", no_argument, &spill, 1},
{"connection", required_argument, nullptr, connect_opt},
{"help", no_argument, nullptr, 'h'},
{nullptr, no_argument, nullptr, 0}};

std::map<int32_t, int32_t> seen_opts;
int32_t rc = 0;
try {
while (rc == 0) {
int32_t option_indxex;
const auto opt = getopt_long(argc, argv, short_opts, long_opts, &option_indxex);

if (-1 == opt) {
if (optind < argc) {
rc = -1;
std::cerr << "Unknown arguments: ";
while (optind < argc) {
std::cerr << argv[optind++] << " ";
}
std::cerr << std::endl;
}
break;
}

if (opt > 0) {
seen_opts[opt]++;
if (seen_opts[opt] > 1) {
std::string long_name = long_opts[option_indxex].name;
std::cerr << "The " << long_name << " argument was given more than once." << std::endl;
rc = -1;
continue;
}
}

switch (opt) {
case 0: {
if (long_opts[option_indxex].flag == &shuffle) {
shuffle_ = true;
} else if (long_opts[option_indxex].flag == &spill) {
cache_builder_.SetSpill(true);
}
break;
}

case 'n': {
num_pipelines_ = std::stoi(optarg);
break;
}

case 'e': {
num_epoches_ = std::stoi(optarg);
break;
}

case 'p': {
int32_t prefetch_sz = std::stoi(optarg);
cache_builder_.SetPrefetchSize(prefetch_sz);
break;
}

case 'a': {
int32_t cache_sz = std::stoi(optarg);
cache_builder_.SetCacheMemSz(cache_sz);
break;
}

case 's': {
num_rows_ = std::stoi(optarg);
break;
}

case 'r': {
row_size_ = std::stoi(optarg);
break;
}

case 'w': {
cfg_.set_num_parallel_workers(std::stoi(optarg));
break;
}

case connect_opt: {
int32_t connection_sz = std::stoi(optarg);
cache_builder_.SetNumConnections(connection_sz);
break;
}

case port_opt: {
int32_t port = std::stoi(optarg);
cache_builder_.SetPort(port);
break;
}

case hostname_opt: {
std::string hostname = optarg;
cache_builder_.SetHostname(hostname);
break;
}

case 'h': // -h or --help
PrintHelp();
rc = -1;
break;

case ':':
std::cerr << "Missing argument for option " << char(optopt) << std::endl;
rc = -1;
break;

case '?': // Unrecognized option
default:
std::cerr << "Unknown option " << char(optopt) << std::endl;
PrintHelp();
rc = -1;
break;
}
}
} catch (const std::exception &e) {
PrintHelp();
rc = -1;
}

if (rc < 0) {
return rc;
}

// We have all the defaults except sample size and average row size which the user must specify.
auto it = seen_opts.find('s');
if (it == seen_opts.end()) {
std::cerr << "Missing sample size." << std::endl;
return -1;
}

it = seen_opts.find('r');
if (it == seen_opts.end()) {
std::cerr << "Missing average row size." << std::endl;
return -1;
}

if (num_rows_ <= 0) {
std::cerr << "Sample size must be positive." << std::endl;
return -1;
}

if (row_size_ <= 0) {
std::cerr << "Average row size must be positive." << std::endl;
return -1;
}

if (num_pipelines_ <= 0) {
std::cerr << "Number of pipelines must be positive." << std::endl;
return -1;
}

if (num_epoches_ <= 0) {
std::cerr << "Number of epoches must be positive." << std::endl;
return -1;
}

if (num_rows_ < num_pipelines_) {
std::cerr << "Sample size is smaller than the number of pipelines." << std::endl;
return -1;
}

pid_lists_.reserve(num_pipelines_);

return 0;
}

Status CachePerfRun::GetSession() {
CacheClientGreeter comm(cache_builder_.GetHostname(), cache_builder_.GetPort(), 1);
RETURN_IF_NOT_OK(comm.ServiceStart());
auto rq = std::make_shared<GenerateSessionIdRequest>();
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
session_ = rq->GetSessionId();
std::cout << "Session: " << session_ << std::endl;
cache_builder_.SetSessionId(session_);
return Status::OK();
}

CachePerfRun::CachePerfRun()
: my_pipeline_(-1),
num_pipelines_(kDftNumOfPipelines),
num_epoches_(kDftNumberOfEpochs),
num_rows_(0),
row_size_(0),
shuffle_(kDftShuffle),
session_(0),
crc_(0),
epoch_sync_cnt_(0) {
cache_builder_.SetSpill(kDftSpill).SetCacheMemSz(kDftCacheSize);
}

CachePerfRun::~CachePerfRun() {
if (session_ != 0) {
Status rc;
CacheClientGreeter comm(cache_builder_.GetHostname(), cache_builder_.GetPort(), 1);
rc = comm.ServiceStart();
if (rc.IsOk()) {
CacheClientInfo cinfo;
cinfo.set_session_id(session_);
auto rq = std::make_shared<DropSessionRequest>(cinfo);
rc = comm.HandleRequest(rq);
if (rc.IsOk()) {
rc = rq->Wait();
if (rc.IsOk()) {
std::cout << "Drop session " << session_ << " successful" << std::endl;
}
}
}
}
// Send an interrupt message to each child.
for (auto msg_qid : msg_send_lists_) {
CachePerfMsg msg;
msg.SetType(CachePerfMsg::MessageType::kInterrupt);
(void)msg.Send(msg_qid);
}
// Wait for each child to return
for (auto pid : pid_lists_) {
int status;
if (waitpid(pid, &status, 0) == -1) {
std::string errMsg = "waitpid fails. errno = " + std::to_string(errno);
std::cerr << errMsg << std::endl;
} else {
MS_LOG(INFO) << "Child pid " << pid << " returns." << std::endl;
}
}
// Remove all the message queues
for (auto msg_qid : msg_send_lists_) {
// Remove the message que and never mind about the return code.
(void)msgctl(msg_qid, IPC_RMID, nullptr);
}
for (auto msg_qid : msg_recv_lists_) {
// Remove the message que and never mind about the return code.
(void)msgctl(msg_qid, IPC_RMID, nullptr);
}
}

void CachePerfRun::PrintEpochSummary() const {
std::cout << std::setw(12) << "Pipeline #" << std::setw(10) << "worker id" << std::setw(11) << "min (μs)"
<< std::setw(11) << "max (μs)" << std::setw(11) << "avg (μs)" << std::setw(14) << "median (μs)"
<< std::setw(14) << "buffer count" << std::setw(18) << "Elapsed time (s)" << std::endl;
for (auto &it : epoch_results_) {
auto epoch_worker_summary = it.second;
std::cout << std::setw(12) << epoch_worker_summary.pipeline() + 1 << std::setw(10) << epoch_worker_summary.worker()
<< std::setw(10) << epoch_worker_summary.min() << std::setw(10) << epoch_worker_summary.max()
<< std::setw(10) << epoch_worker_summary.avg() << std::setw(13) << epoch_worker_summary.med()
<< std::setw(14) << epoch_worker_summary.cnt() << std::setw(18) << epoch_worker_summary.elapse()
<< std::endl;
}
}

Status CachePerfRun::ListenToPipeline(int32_t workerId) {
TaskManager::FindMe()->Post();
int32_t qID = msg_recv_lists_[workerId];
do {
RETURN_IF_INTERRUPTED();
CachePerfMsg msg;
RETURN_IF_NOT_OK(msg.Receive(qID));
// Decode the messages.
auto type = msg.GetType();
char *p = msg.GetMutableBuffer();
switch (type) {
case CachePerfMsg::MessageType::kEpochResult: {
PipelineWorkerEpochSummary epoch_worker_summary;
CHECK_FAIL_RETURN_UNEXPECTED(epoch_worker_summary.ParseFromArray(p, msg.GetProtoBufSz()), "Parse fail");
{
auto pipeline = epoch_worker_summary.pipeline();
auto worker = epoch_worker_summary.worker();
std::unique_lock<std::mutex> lock(mux_);
// sort by pipeline/worker
auto r =
epoch_results_.emplace(std::pair<int32_t, int32_t>(pipeline, worker), std::move(epoch_worker_summary));
CHECK_FAIL_RETURN_UNEXPECTED(r.second, "Insert failed");
}
break;
}
case CachePerfMsg::MessageType::kEpochEnd: {
EpochDone proto;
CHECK_FAIL_RETURN_UNEXPECTED(proto.ParseFromArray(p, msg.GetProtoBufSz()), "Parse fail");
auto n = epoch_sync_cnt_.fetch_add(1);
if (n + 1 == num_pipelines_) {
pipeline_wp_.Set();
}
break;
}
case CachePerfMsg::MessageType::kInterrupt: {
TaskManager::WakeUpWatchDog();
return Status::OK();
}
case CachePerfMsg::kError: {
ErrorMsg proto;
CHECK_FAIL_RETURN_UNEXPECTED(proto.ParseFromArray(p, msg.GetProtoBufSz()), "Parse fail");
return Status(static_cast<const StatusCode>(proto.rc()), proto.msg());
}
default:
std::string errMsg = "Unknown request type: " + std::to_string(type);
MS_LOG(ERROR) << errMsg;
RETURN_STATUS_UNEXPECTED(errMsg);
break;
}
} while (true);
return Status::OK();
}

Status CachePerfRun::Run() {
// Now we bring up TaskManager.
RETURN_IF_NOT_OK(Services::CreateInstance());
// Handle Control-C
RegisterHandlers();

// Get a session from the server.
RETURN_IF_NOT_OK(GetSession());

// Generate a random crc.
auto mt = GetRandomDevice();
std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<int32_t>::max());
crc_ = distribution(mt);
std::cout << "CRC: " << crc_ << std::endl;

// Create all the resources required by the pipelines before we fork.
for (auto i = 0; i < num_pipelines_; ++i) {
// We will use shared message queues for communication between parent (this process)
// and each pipelines.
auto access_mode = S_IRUSR | S_IWUSR;
int32_t msg_send_qid = msgget(IPC_PRIVATE, IPC_CREAT | IPC_EXCL | access_mode);
if (msg_send_qid == -1) {
std::string errMsg = "Unable to create a message queue. Errno = " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg);
}
msg_send_lists_.push_back(msg_send_qid);
int32_t msg_recv_qid = msgget(IPC_PRIVATE, IPC_CREAT | IPC_EXCL | access_mode);
if (msg_recv_qid == -1) {
std::string errMsg = "Unable to create a message queue. Errno = " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg);
}
msg_recv_lists_.push_back(msg_recv_qid);
}

// Now we create the children knowing all two sets of message queues are constructed.
for (auto i = 0; i < num_pipelines_; ++i) {
auto pid = fork();
if (pid == 0) {
// Child. We will call another binary but with different (hidden) parameters.
// The parent process is waiting on a wait post. Any error we hit here we must interrupt the
// parent process
auto interrupt_parent = [this, i]() {
CachePerfMsg msg;
msg.SetType(CachePerfMsg::MessageType::kInterrupt);
msg.Send(msg_recv_lists_[i]);
};
const std::string self_proc = "/proc/self/exe";
std::string canonical_path;
canonical_path.resize(400); // PATH_MAX is large. This value should be big enough for our use.
// Some lower level OS library calls are needed here to determine the binary path.
if (realpath(self_proc.data(), canonical_path.data()) == nullptr) {
std::cerr << "Failed to identify cache_perf binary path: " + std::to_string(errno) << ": " << strerror(errno)
<< std::endl;
interrupt_parent();
// Call _exit instead of exit because we will hang in TaskManager destructor for a forked child process.
_exit(-1);
}
canonical_path.resize(strlen(canonical_path.data()));
int last_seperator = canonical_path.find_last_of('/');
if (last_seperator == std::string::npos) {
std::cerr << "Canonical path can't locate / " << canonical_path << std::endl;
interrupt_parent();
// Call _exit instead of exit because we will hang in TaskManager destructor for a forked child process.
_exit(-1);
}
// truncate the binary name so we are left with the absolute path of cache_admin binary
canonical_path.resize(last_seperator + 1);
std::string cache_pipeline_binary = canonical_path + std::string(kCachePipelineBinary);

std::string pipeline_cfg = std::to_string(i) + "," + std::to_string(session_) + "," + std::to_string(crc_) + "," +
std::to_string(msg_send_lists_[i]) + "," + std::to_string(msg_recv_lists_[i]) + "," +
std::to_string(num_pipelines_) + "," + std::to_string(num_epoches_) + "," +
std::to_string(num_rows_) + "," + std::to_string(row_size_) + "," +
std::to_string(cfg_.num_parallel_workers()) + "," +
(shuffle_ ? std::string("true").data() : std::string("false").data());
std::string client_cfg = cache_builder_.GetHostname() + "," + std::to_string(cache_builder_.GetPort()) + "," +
std::to_string(cache_builder_.GetPrefetchSize()) + "," +
std::to_string(cache_builder_.GetCacheMemSz()) + "," +
std::to_string(cache_builder_.GetNumConnections()) + "," +
(cache_builder_.isSpill() ? std::string("true").data() : std::string("false").data());
char *argv[4];
argv[0] = const_cast<char *>(kCachePipelineBinary);
argv[1] = pipeline_cfg.data();
argv[2] = client_cfg.data();
argv[3] = nullptr;
// Invoke the binary.
execv(cache_pipeline_binary.data(), argv);
std::cerr << "Unable to exec. Errno = " + std::to_string(errno) << ": " << strerror(errno) << std::endl;
interrupt_parent();
// Call _exit instead of exit because we will hang TaskManager destructor for a forked child process.
_exit(-1);
} else if (pid > 0) {
std::cout << "Pipeline number " << i + 1 << " has been created with process id: " << pid << std::endl;
pid_lists_.push_back(pid);
} else {
std::string errMsg = "Failed to fork process for cache pipeline: " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg);
}
}

// Spawn a few threads to monitor the communications from the pipeline.
RETURN_IF_NOT_OK(vg_.ServiceStart());

auto f = std::bind(&CachePerfRun::ListenToPipeline, this, std::placeholders::_1);
for (auto i = 0; i < num_pipelines_; ++i) {
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Queue listener", std::bind(f, i)));
}

// Wait until all pipelines finish the first epoch.
RETURN_IF_NOT_OK(pipeline_wp_.Wait());

std::cout << "Epoch one (build phase) per pipeline per worker summary. Buffer size = " << cfg_.rows_per_buffer()
<< std::endl;
PrintEpochSummary();

// Get some stat but we need to connect. The server will thinks it is the (n+1) pipeline
RETURN_IF_NOT_OK(cache_builder_.Build(&cc_));
Status rc = cc_->CreateCache(crc_, false);
// Duplicate key is fine.
if (rc.IsError() && rc.get_code() != StatusCode::kDuplicateKey) {
return rc;
}

CacheServiceStat stat{};
RETURN_IF_NOT_OK(cc_->GetStat(&stat));

std::cout << "Get statistics for this session:\n";
std::cout << std::setw(12) << "Mem cached" << std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size"
<< std::setw(10) << "Numa hit" << std::endl;
std::string stat_mem_cached;
std::string stat_disk_cached;
std::string stat_avg_cached;
std::string stat_numa_hit;
stat_mem_cached = (stat.num_mem_cached == 0) ? "n/a" : std::to_string(stat.num_mem_cached);
stat_disk_cached = (stat.num_disk_cached == 0) ? "n/a" : std::to_string(stat.num_disk_cached);
stat_avg_cached = (stat.avg_cache_sz == 0) ? "n/a" : std::to_string(stat.avg_cache_sz);
stat_numa_hit = (stat.num_numa_hit == 0) ? "n/a" : std::to_string(stat.num_numa_hit);

std::cout << std::setw(12) << stat_mem_cached << std::setw(12) << stat_disk_cached << std::setw(16) << stat_avg_cached
<< std::setw(10) << stat_numa_hit << std::endl;

// Toggle write mode off since the rest are just read only.
// Simplest way is call this special internal function.
cc_->ServerRunningOutOfResources();

// The rest of the epochs are just fetching.
auto epoch_num = 2;
while (epoch_num <= num_epoches_) {
epoch_sync_cnt_ = 0;
pipeline_wp_.Clear();
epoch_results_.clear();
// Signal each pipeline to start
for (auto msg_qid : msg_send_lists_) {
CachePerfMsg msg;
msg.SetType(CachePerfMsg::MessageType::kEpochStart);
(void)msg.Send(msg_qid);
}
// Wait for the child to finish
RETURN_IF_NOT_OK(pipeline_wp_.Wait());
std::cout << "Epoch " << epoch_num
<< " (read phase) per pipeline per worker summary. Buffer size = " << cc_->GetPrefetchSize() << std::endl;
PrintEpochSummary();
++epoch_num;
}

// Destroy the cache. We no longer need it around.
RETURN_IF_NOT_OK(cc_->DestroyCache());

// Unreserve the session
CacheClientInfo cinfo;
cinfo.set_session_id(session_);
auto rq = std::make_shared<DropSessionRequest>(cinfo);
RETURN_IF_NOT_OK(cc_->PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
std::cout << "Drop session " << session_ << " successful" << std::endl;
session_ = 0;

return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 100
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.h View File

@@ -0,0 +1,100 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_RUN_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_RUN_H_

#include <getopt.h>
#include <atomic>
#include <cstdint>
#include <limits>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/perf/cache_msg.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/task_manager.h"
#include "minddata/dataset/util/wait_post.h"

namespace mindspore {
namespace dataset {

constexpr int32_t kDftNumOfPipelines = 8;
constexpr int32_t kDftNumberOfEpochs = 10;
constexpr int32_t kDftCacheSize = 0;
constexpr bool kDftShuffle = false;
constexpr bool kDftSpill = false;

class CachePerfRun {
public:
static const char kCachePipelineBinary[];
CachePerfRun();
~CachePerfRun();
void PrintHelp();
int32_t ProcessArgs(int argc, char **argv);

void Print(std::ostream &out) const {
out << "Number of pipelines: " << num_pipelines_ << "\n"
<< "Number of epochs: " << num_epoches_ << "\n"
<< "Sample size: " << num_rows_ << "\n"
<< "Average row size: " << row_size_ << "\n"
<< "Shuffle: " << std::boolalpha << shuffle_;
}

friend std::ostream &operator<<(std::ostream &out, const CachePerfRun &cp) {
cp.Print(out);
return out;
}

Status Run();

private:
std::mutex mux_;
int32_t my_pipeline_;
int32_t num_pipelines_;
int32_t num_epoches_;
int64_t num_rows_;
int32_t row_size_;
bool shuffle_;
CacheClient::Builder cache_builder_;
session_id_type session_;
int32_t crc_;
std::vector<int32_t> pid_lists_;
std::vector<int32_t> msg_send_lists_;
std::vector<int32_t> msg_recv_lists_;
TaskGroup vg_;
std::atomic<int32_t> epoch_sync_cnt_;
WaitPost pipeline_wp_;
std::map<std::pair<int32_t, int32_t>, PipelineWorkerEpochSummary> epoch_results_;
ConfigManager cfg_;
std::shared_ptr<CacheClient> cc_;

Status GetSession();
Status ListenToPipeline(int32_t workerId);
void PrintEpochSummary() const;
};
} // namespace dataset
} // namespace mindspore

#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_RUN_H_

+ 44
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline.cc View File

@@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifdef USE_GLOG
#include <glog/logging.h>
#endif
#include <string.h>
#include "minddata/dataset/engine/cache/perf/cache_pipeline_run.h"
namespace ds = mindspore::dataset;

int main(int argc, char **argv) {
#ifdef USE_GLOG
FLAGS_log_dir = "/tmp";
FLAGS_minloglevel = google::WARNING;
google::InitGoogleLogging(argv[0]);
#endif
ds::CachePipelineRun cachePipelineRun;
if (cachePipelineRun.ProcessArgs(argc, argv) == 0) {
ds::Status rc = cachePipelineRun.Run();
// If we hit any error, send the rc back to the parent.
if (rc.IsError()) {
ds::ErrorMsg proto;
proto.set_rc(static_cast<int32_t>(rc.get_code()));
proto.set_msg(rc.ToString());
ds::CachePerfMsg msg;
(void)cachePipelineRun.SendMessage(&msg, ds::CachePerfMsg::MessageType::kError, &proto);
}
return static_cast<int>(rc.get_code());
}
return 0;
}

+ 471
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.cc View File

@@ -0,0 +1,471 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/cache/perf/cache_pipeline_run.h"
#include <string.h>
#include <sys/types.h>
#include <algorithm>
#include <chrono>
#include <iomanip>
#include <sstream>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/services.h"

namespace mindspore {
namespace dataset {
void CachePipelineRun::PrintHelp() { std::cout << "Please run the executable cache_perf instead." << std::endl; }

int32_t CachePipelineRun::ProcessArgs(int argc, char **argv) {
if (argc != 3) {
PrintHelp();
return -1;
}

try {
std::stringstream cfg_ss(argv[1]);
std::string s;
int32_t numArgs = 0;
while (std::getline(cfg_ss, s, ',')) {
if (numArgs == 0) {
my_pipeline_ = std::stoi(s);
} else if (numArgs == 1) {
session_ = std::stoul(s);
cache_builder_.SetSessionId(session_);
} else if (numArgs == 2) {
crc_ = std::stoi(s);
} else if (numArgs == 3) {
recv_id_ = std::stoi(s);
} else if (numArgs == 4) {
send_id_ = std::stoi(s);
} else if (numArgs == 5) {
num_pipelines_ = std::stoi(s);
} else if (numArgs == 6) {
num_epoches_ = std::stoi(s);
} else if (numArgs == 7) {
num_rows_ = std::stol(s);
} else if (numArgs == 8) {
row_size_ = std::stoi(s);
} else if (numArgs == 9) {
cfg_.set_num_parallel_workers(std::stol(s));
} else if (numArgs == 10) {
shuffle_ = strcmp(s.data(), "true") == 0;
}
++numArgs;
}
if (numArgs != 11) {
std::cerr << "Incomplete arguments. Expect 11. But get " << numArgs << std::endl;
return -1;
}
std::stringstream client_ss(argv[2]);
numArgs = 0;
while (std::getline(client_ss, s, ',')) {
if (numArgs == 0) {
cache_builder_.SetHostname(s);
} else if (numArgs == 1) {
cache_builder_.SetPort(std::stoi(s));
} else if (numArgs == 2) {
cache_builder_.SetPrefetchSize(std::stoi(s));
} else if (numArgs == 3) {
cache_builder_.SetCacheMemSz(std::stoi(s));
} else if (numArgs == 4) {
cache_builder_.SetNumConnections(std::stoi(s));
} else if (numArgs == 5) {
cache_builder_.SetSpill(strcmp(s.data(), "true") == 0);
}
++numArgs;
}
if (numArgs != 6) {
std::cerr << "Incomplete arguments. Expect 6. But get " << numArgs << std::endl;
return -1;
}
} catch (const std::exception &e) {
std::cerr << "Parse error: " << e.what() << std::endl;
return -1;
}
return 0;
}

CachePipelineRun::CachePipelineRun()
: my_pipeline_(-1),
num_pipelines_(kDftNumOfPipelines),
num_epoches_(kDftNumberOfEpochs),
num_rows_(0),
row_size_(0),
shuffle_(kDftShuffle),
session_(0),
crc_(0),
send_id_(-1),
recv_id_(-1),
start_row_(-1),
end_row_(-1) {
cache_builder_.SetSpill(kDftSpill).SetCacheMemSz(kDftCacheSize);
}

CachePipelineRun::~CachePipelineRun() {
CachePerfMsg msg;
(void)SendMessage<ErrorMsg>(&msg, CachePerfMsg::MessageType::kInterrupt, nullptr);
}

Status CachePipelineRun::ListenToParent() {
TaskManager::FindMe()->Post();
do {
RETURN_IF_INTERRUPTED();
CachePerfMsg msg;
RETURN_IF_NOT_OK(msg.Receive(recv_id_));
// Decode the messages.
auto type = msg.GetType();
switch (type) {
case CachePerfMsg::MessageType::kInterrupt: {
TaskManager::WakeUpWatchDog();
return Status::OK();
}
case CachePerfMsg::MessageType::kEpochStart: {
pipeline_wp_.Set();
break;
}
default:
std::string errMsg = "Unknown request type: " + std::to_string(type);
MS_LOG(ERROR) << errMsg;
RETURN_STATUS_UNEXPECTED(errMsg);
break;
}
} while (true);

return Status::OK();
}

Status CachePipelineRun::Run() {
RETURN_IF_NOT_OK(cache_builder_.Build(&cc_));
RETURN_IF_NOT_OK(vg_.ServiceStart());

auto num_workers = cfg_.num_parallel_workers();
io_block_queues_.Init(num_workers, cfg_.op_connector_size());

RETURN_IF_NOT_OK(io_block_queues_.Register(&vg_));

Status rc = cc_->CreateCache(crc_, false);
// Duplicate key is fine.
if (rc.IsError() && rc.get_code() != StatusCode::kDuplicateKey) {
return rc;
}

// Log a warning level message so we can see it in the log file when it starts.
MS_LOG(WARNING) << "Pipeline number " << my_pipeline_ + 1 << " successfully creating cache service." << std::endl;

// Spawn a thread to listen to the parent process
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Queue listener", std::bind(&CachePipelineRun::ListenToParent, this)));

RETURN_IF_NOT_OK(RunFirstEpoch());

// The rest of the epochs are just fetching.
auto remaining_epochs = num_epoches_ - 1;
while (remaining_epochs > 0) {
// Wait for parent process signal to start
pipeline_wp_.Wait();
pipeline_wp_.Clear();
RETURN_IF_NOT_OK(RunReadEpoch());
--remaining_epochs;
}

// The listener thread is blocked on a shared message queue. It will be waken up by
// the parent process which will send an interrupt message to it, and this program
// will exit.
RETURN_IF_NOT_OK(vg_.join_all());
return Status::OK();
}

Status CachePipelineRun::RunFirstEpoch() {
auto sz = num_rows_ / num_pipelines_;
start_row_ = my_pipeline_ * sz;
end_row_ = (my_pipeline_ + 1) * sz - 1;
if (my_pipeline_ + 1 == num_pipelines_) {
end_row_ = num_rows_ - 1;
}
std::cout << "Pipeline number " << my_pipeline_ + 1 << " row id range: [" << start_row_ << "," << end_row_ << "]"
<< std::endl;

// Spawn the worker threads.
auto f = std::bind(&CachePipelineRun::WriterWorkerEntry, this, std::placeholders::_1);
std::vector<Task *> worker_threads;
auto num_workers = cfg_.num_parallel_workers();
worker_threads.reserve(num_workers);
for (int32_t i = 0; i < num_workers; ++i) {
Task *pTask;
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Parallel worker", std::bind(f, i), &pTask));
worker_threads.push_back(pTask);
}

std::vector<row_id_type> keys;
auto rows_per_buffer = cfg_.rows_per_buffer();
keys.reserve(rows_per_buffer);
int32_t worker_id = 0;
for (auto i = start_row_; i <= end_row_; ++i) {
keys.push_back(i);
if (keys.size() == rows_per_buffer) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk)));
keys.clear();
}
}
if (!keys.empty()) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk)));
keys.clear();
}

// Shutdown threads and wait for them to come back.
for (int32_t i = 0; i < num_workers; i++) {
RETURN_IF_NOT_OK(
io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
}
for (auto *pTask : worker_threads) {
RETURN_IF_NOT_OK(pTask->Join(Task::WaitFlag::kBlocking));
}

// Send a message saying epoch one done for this pipeline.
EpochDone proto;
proto.set_pipeline(my_pipeline_);
CachePerfMsg msg;
RETURN_IF_NOT_OK(SendMessage(&msg, CachePerfMsg::MessageType::kEpochEnd, &proto));

return Status::OK();
}

Status CachePipelineRun::WriterWorkerEntry(int32_t worker_id) {
Status rc;
TaskManager::FindMe()->Post();
int64_t min_val = std::numeric_limits<int64_t>::max();
int64_t max_val = 0;
int64_t total_val = 0;
int64_t cnt = 0;
std::vector<int64_t> duration;
duration.reserve(num_rows_ / num_pipelines_ / cfg_.num_parallel_workers());
bool resource_err = false;
auto col_desc = std::make_unique<ColDescriptor>("int64", DataType(DataType::DE_INT64), TensorImpl::kFlexible, 1);
auto num_elements = row_size_ / sizeof(int64_t);
TensorShape shape(std::vector<dsize_t>(1, num_elements));
std::unique_ptr<IOBlock> blk;
auto epoch_start = std::chrono::steady_clock::now();
do {
RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk));
std::vector<int64_t> keys;
RETURN_IF_NOT_OK(blk->GetKeys(&keys));
if (keys.empty()) {
// empty key is a quit signal for workers
break;
}
// Once we hit resource error, we drain the io block. No point to send anything to the server.
if (!resource_err) {
auto buffer = std::make_unique<DataBuffer>(cnt++, DataBuffer::kDeBFlagNone);
auto tensor_table = std::make_unique<TensorQTable>();
for (auto id : keys) {
TensorRow row;
std::shared_ptr<Tensor> element;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, col_desc->type(), &element));
row.setId(id);
// CreateEmpty allocates the memory but in virutal address. Let's commit the memory
// so we can get an accurate timing.
auto it = element->begin<int64_t>();
for (auto i = 0; i < num_elements; ++i, ++it) {
*it = i;
}
row.push_back(std::move(element));
tensor_table->push_back(std::move(row));
}
buffer->set_tensor_table(std::move(tensor_table));
// Measure the time to call WriteBuffer
auto start_tick = std::chrono::steady_clock::now();
rc = cc_->WriteBuffer(std::move(buffer));
auto end_tick = std::chrono::steady_clock::now();
if (rc.IsError()) {
if (rc.IsOutofMemory() || rc.IsNoSpace()) {
MS_LOG(WARNING) << "Pipeline number " << my_pipeline_ + 1 << " worker id " << worker_id << ": "
<< rc.ToString();
resource_err = true;
cc_->ServerRunningOutOfResources();
continue;
} else {
return rc;
}
} else {
int64_t ms = std::chrono::duration_cast<std::chrono::microseconds>(end_tick - start_tick).count();
min_val = std::min(min_val, ms);
max_val = std::max(max_val, ms);
duration.push_back(ms);
total_val += ms;
}
}
} while (true);

auto epoch_end = std::chrono::steady_clock::now();
int64_t elapse_time = std::chrono::duration_cast<std::chrono::seconds>(epoch_end - epoch_start).count();

PipelineWorkerEpochSummary proto;
proto.set_pipeline(my_pipeline_);
proto.set_worker(worker_id);
proto.set_min(min_val);
proto.set_max(max_val);
proto.set_elapse(elapse_time);
auto sz = duration.size();
proto.set_cnt(sz);
if (sz > 0) {
// median
auto n = sz / 2;
std::nth_element(duration.begin(), duration.begin() + n, duration.end());
auto median = duration[n];
proto.set_med(median);
// average
int64_t avg = total_val / sz;
proto.set_avg(avg);
}
CachePerfMsg msg;
RETURN_IF_NOT_OK(SendMessage(&msg, CachePerfMsg::MessageType::kEpochResult, &proto));
return Status::OK();
}

Status CachePipelineRun::RunReadEpoch() {
std::vector<row_id_type> keys;
auto rows_per_buffer = cc_->GetPrefetchSize(); // We will use prefetch size to read.
auto num_workers = cfg_.num_parallel_workers();
keys.reserve(rows_per_buffer);
// Spawn workers
auto f = std::bind(&CachePipelineRun::ReaderWorkerEntry, this, std::placeholders::_1);
std::vector<Task *> worker_threads;
worker_threads.reserve(num_workers);
for (int32_t i = 0; i < num_workers; ++i) {
Task *pTask;
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Parallel worker", std::bind(f, i), &pTask));
worker_threads.push_back(pTask);
}

std::vector<row_id_type> all_keys;
all_keys.reserve(end_row_ - start_row_ + 1);
for (auto i = start_row_; i <= end_row_; ++i) {
all_keys.push_back((i));
}
// If we need to shuffle the keys
if (shuffle_) {
std::shuffle(all_keys.begin(), all_keys.end(), GetRandomDevice());
}

int32_t worker_id = 0;
for (auto id : all_keys) {
keys.push_back(id);
if (keys.size() == rows_per_buffer) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk)));
keys.clear();
}
}
if (!keys.empty()) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk)));
keys.clear();
}

// Shutdown threads and wait for them to come back.
for (int32_t i = 0; i < num_workers; i++) {
RETURN_IF_NOT_OK(
io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
}
for (auto *pTask : worker_threads) {
RETURN_IF_NOT_OK(pTask->Join(Task::WaitFlag::kBlocking));
}

// Send a message saying epoch one done for this pipeline.
EpochDone proto;
proto.set_pipeline(my_pipeline_);
CachePerfMsg msg;
RETURN_IF_NOT_OK(SendMessage(&msg, CachePerfMsg::MessageType::kEpochEnd, &proto));
return Status::OK();
}

Status CachePipelineRun::ReaderWorkerEntry(int32_t worker_id) {
Status rc;
TaskManager::FindMe()->Post();
int64_t min_val = std::numeric_limits<int64_t>::max();
int64_t max_val = 0;
int64_t total_val = 0;
int64_t cnt = 0;
std::vector<int64_t> duration;
duration.reserve(num_rows_ / num_pipelines_ / cfg_.num_parallel_workers());
std::unique_ptr<IOBlock> blk;
auto epoch_start = std::chrono::steady_clock::now();
do {
RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk));
std::vector<int64_t> keys;
RETURN_IF_NOT_OK(blk->GetKeys(&keys));
if (keys.empty()) {
// empty key is a quit signal for workers
break;
}
std::vector<row_id_type> prefetch_keys;
prefetch_keys.reserve(keys.size());

// Filter out all those keys that unlikely we will find at the server
for (auto row_id : keys) {
if (!cc_->KeyIsCacheMiss(row_id)) {
prefetch_keys.push_back(row_id);
}
}
// Early exit if nothing to fetch
if (prefetch_keys.empty()) {
continue;
}
// Get the rows from the server
TensorTable ttbl;
// Measure how long it takes for the row to come back.
auto start_tick = std::chrono::steady_clock::now();
RETURN_IF_NOT_OK(cc_->GetRows(prefetch_keys, &ttbl));
auto end_tick = std::chrono::steady_clock::now();
int64_t ms = std::chrono::duration_cast<std::chrono::microseconds>(end_tick - start_tick).count();
min_val = std::min(min_val, ms);
max_val = std::max(max_val, ms);
duration.push_back(ms);
total_val += ms;
++cnt;
} while (true);

auto epoch_end = std::chrono::steady_clock::now();
int64_t elapse_time = std::chrono::duration_cast<std::chrono::seconds>(epoch_end - epoch_start).count();

PipelineWorkerEpochSummary proto;
proto.set_pipeline(my_pipeline_);
proto.set_worker(worker_id);
proto.set_min(min_val);
proto.set_max(max_val);
proto.set_elapse(elapse_time);
auto sz = duration.size();
proto.set_cnt(sz);
if (sz > 0) {
// median
auto n = sz / 2;
std::nth_element(duration.begin(), duration.begin() + n, duration.end());
auto median = duration[n];
proto.set_med(median);
// average
int64_t avg = total_val / sz;
proto.set_avg(avg);
}
CachePerfMsg msg;
RETURN_IF_NOT_OK(SendMessage(&msg, CachePerfMsg::MessageType::kEpochResult, &proto));
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 117
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.h View File

@@ -0,0 +1,117 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PIPELINE_RUN_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PIPELINE_RUN_H_

#include <getopt.h>
#include <atomic>
#include <cstdint>
#include <limits>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/perf/cache_msg.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/task_manager.h"
#include "minddata/dataset/util/wait_post.h"

namespace mindspore {
namespace dataset {

constexpr int32_t kDftNumOfPipelines = 8;
constexpr int32_t kDftNumberOfEpochs = 10;
constexpr int32_t kDftCacheSize = 0;
constexpr bool kDftShuffle = false;
constexpr bool kDftSpill = false;

class CachePipelineRun {
public:
CachePipelineRun();
~CachePipelineRun();
static void PrintHelp();
int32_t ProcessArgs(int argc, char **argv);

void Print(std::ostream &out) const {
out << "Number of pipelines: " << num_pipelines_ << "\n"
<< "Number of epochs: " << num_epoches_ << "\n"
<< "Sample size: " << num_rows_ << "\n"
<< "Average row size: " << row_size_ << "\n"
<< "Shuffle: " << std::boolalpha << shuffle_;
}

friend std::ostream &operator<<(std::ostream &out, const CachePipelineRun &cp) {
cp.Print(out);
return out;
}

Status Run();

template <typename T>
Status SendMessage(CachePerfMsg *msg, CachePerfMsg::MessageType type, T *proto) {
RETURN_UNEXPECTED_IF_NULL(msg);
msg->SetType(type);
if (proto != nullptr) {
auto size_needed = proto->ByteSizeLong();
CHECK_FAIL_RETURN_UNEXPECTED(
size_needed <= kSharedMessageSize,
"Shared message set too small. Suggest to increase to " + std::to_string(size_needed));
CHECK_FAIL_RETURN_UNEXPECTED(proto->SerializeToArray(msg->GetMutableBuffer(), kSharedMessageSize),
"Serialization fails");
msg->SetProtoBufSz(size_needed);
}
RETURN_IF_NOT_OK(msg->Send(send_id_));
return Status::OK();
}

private:
int32_t my_pipeline_;
int32_t num_pipelines_;
int32_t num_epoches_;
int64_t num_rows_;
int32_t row_size_;
bool shuffle_;
CacheClient::Builder cache_builder_;
session_id_type session_;
int32_t crc_;
TaskGroup vg_;
WaitPost pipeline_wp_;
int32_t send_id_;
int32_t recv_id_;
row_id_type start_row_;
row_id_type end_row_;
ConfigManager cfg_;
QueueList<std::unique_ptr<IOBlock>> io_block_queues_; // queues of IOBlocks
std::shared_ptr<CacheClient> cc_;

Status ListenToParent();
Status RunFirstEpoch();
Status RunReadEpoch();
Status WriterWorkerEntry(int32_t worker_id);
Status ReaderWorkerEntry(int32_t worker_id);
};
} // namespace dataset
} // namespace mindspore

#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PIPELINE_RUN_H_

mindspore/ccsrc/minddata/dataset/util/storage_container.cc → mindspore/ccsrc/minddata/dataset/engine/cache/storage_container.cc View File

@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/util/storage_container.h"
#include "minddata/dataset/engine/cache/storage_container.h"

#include <fcntl.h>
#include <sys/stat.h>

mindspore/ccsrc/minddata/dataset/util/storage_container.h → mindspore/ccsrc/minddata/dataset/engine/cache/storage_container.h View File


mindspore/ccsrc/minddata/dataset/util/storage_manager.cc → mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.cc View File

@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/util/storage_manager.h"
#include "minddata/dataset/engine/cache/storage_manager.h"

#include <iomanip>
#include <sstream>

mindspore/ccsrc/minddata/dataset/util/storage_manager.h → mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.h View File

@@ -21,6 +21,7 @@
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/storage_container.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/util/lock.h"
@@ -28,7 +29,6 @@
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/service.h"
#include "minddata/dataset/util/slice.h"
#include "minddata/dataset/util/storage_container.h"

using ListOfContainers = std::vector<std::shared_ptr<mindspore::dataset::StorageContainer>>;
namespace mindspore {

+ 19
- 20
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc View File

@@ -271,29 +271,18 @@ Status CacheBase::PrefetchRows(const std::vector<row_id_type> &keys, std::vector
}
// Get the rows from the server
TensorTable ttbl;
Status rc = cache_client_->GetRows(prefetch_keys, &ttbl);
if (rc.IsOk()) {
auto row_it = ttbl.begin();
for (auto row_id : prefetch_keys) {
auto &row = *row_it;
if (row.empty()) {
cache_miss->push_back(row_id);
}
// Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row
RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
++row_it;
}
} else {
// In case any thread is waiting for the rows to come back and blocked on a semaphore,
// we will put an empty row in the local cache.
for (auto row_id : prefetch_keys) {
TensorRow row;
row.setId(row_id);
RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl));
auto row_it = ttbl.begin();
for (auto row_id : prefetch_keys) {
auto &row = *row_it;
if (row.empty()) {
cache_miss->push_back(row_id);
}
// Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row
RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
++row_it;
}
return rc;
return Status::OK();
}

Status CacheBase::Prefetcher(int32_t worker_id) {
@@ -322,6 +311,16 @@ Status CacheBase::Prefetcher(int32_t worker_id) {
return rc;
}
} while (rc.IsNetWorkError());
// In case any thread is waiting for the rows to come back and blocked on a semaphore,
// we will put an empty row in the local cache.
if (rc.IsError() && AllowCacheMiss()) {
for (auto row_id : prefetch_keys) {
TensorRow row;
row.setId(row_id);
RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
cache_miss.push_back(row_id);
}
}
} else {
if (AllowCacheMiss()) {
// This code path is for CacheLookupOp acting as a sampler. If we get a eoe from


+ 0
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h View File

@@ -24,7 +24,6 @@
#include <vector>
#include "minddata/dataset/engine/connector.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"


+ 1
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc View File

@@ -309,8 +309,7 @@ Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::sha
if (st_.compare_exchange_strong(expected, State::kDirty)) {
// We will do a deep copy but write directly into CacheRequest protobuf or shared memory
Status rc;
cleaner_copy_ =
std::make_shared<CacheRowRequest>(cc->server_connection_id_, cc->cookie(), cc->SupportLocalClient());
cleaner_copy_ = std::make_shared<CacheRowRequest>(cc.get());
rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row);
if (rc.IsOk()) {
// Send the request async. The cleaner will check the return code.


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc View File

@@ -153,7 +153,7 @@ Status CacheOp::WaitForCachingAllRows() {
bool BuildPhaseDone = true;
do {
RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
BuildPhaseDone = stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase);
BuildPhaseDone = stat.cache_service_state == static_cast<uint8_t>(CacheServiceState::kFetchPhase);
if (!BuildPhaseDone) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}


+ 77
- 1
mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.cc View File

@@ -24,7 +24,7 @@ namespace mindspore {
namespace dataset {

// Constructor
CacheErrorPass::CacheErrorPass() : is_cached_(false) {}
CacheErrorPass::CacheErrorPass() : is_cached_(false), is_mappable_(false) {}

// Identifies the subtree below this node as being cached
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
@@ -75,5 +75,81 @@ Status CacheErrorPass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modifi
return Status::OK();
}
#endif

Status CacheErrorPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}

Status CacheErrorPass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}

Status CacheErrorPass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}

Status CacheErrorPass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}

Status CacheErrorPass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}

Status CacheErrorPass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}

Status CacheErrorPass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}

Status CacheErrorPass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}

Status CacheErrorPass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}

Status CacheErrorPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}

Status CacheErrorPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
// Turn off the flag that we're under a merge op
is_cached_ = false;
return Status::OK();
}

// Currently, returns an error if RepeatOp exists under a cache
// Because there is no operator in the cache hit stream to consume eoes, caching above repeat causes problem.
Status CacheErrorPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
if (is_cached_ && is_mappable_) {
RETURN_STATUS_UNEXPECTED("Repeat is not supported as a descendant operator under a mappable cache.");
}

return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 73
- 0
mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h View File

@@ -67,8 +67,81 @@ class CacheErrorPass : public NodePass {
Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override;
#endif

/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;

/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override;

/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override;

/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override;

/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override;

/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override;

/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override;

/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override;

/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;

/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;

/// \brief Identifies the subtree above this node as not being cached
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;

/// \brief Identifies and block repeat under cache scenarios
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;

private:
bool is_cached_;
bool is_mappable_;
};
} // namespace dataset
} // namespace mindspore


+ 0
- 3
mindspore/ccsrc/minddata/dataset/util/CMakeLists.txt View File

@@ -3,7 +3,6 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
add_library(utils OBJECT
arena.cc
buddy.cc
cache_pool.cc
circular_pool.cc
data_helper.cc
memory_pool.cc
@@ -16,8 +15,6 @@ add_library(utils OBJECT
lock.cc
semaphore.cc
status.cc
storage_container.cc
storage_manager.cc
slice.cc
path.cc
wait_post.cc


+ 5
- 0
mindspore/ccsrc/minddata/dataset/util/allocator.h View File

@@ -94,6 +94,11 @@ Status MakeUnique(std::unique_ptr<T[], std::function<void(T *)>> *out, C alloc,
CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "size must be positive");
try {
T *data = alloc.allocate(n);
// Some of our implementation of allocator (e.g. NumaAllocator) don't throw std::bad_alloc.
// So we have to catch for null ptr
if (data == nullptr) {
return Status(StatusCode::kOutOfMemory);
}
if (!std::is_arithmetic<T>::value) {
for (auto i = 0; i < n; i++) {
std::allocator_traits<C>::construct(alloc, &(data[i]), std::forward<Args>(args)...);


+ 12
- 0
mindspore/ccsrc/minddata/dataset/util/path.h View File

@@ -78,6 +78,18 @@ class Path {

Path operator/(const char *);

bool operator==(const Path &rhs) const { return (path_ == rhs.path_); }

bool operator!=(const Path &rhs) const { return (path_ != rhs.path_); }

bool operator<(const Path &rhs) const { return (path_ < rhs.path_); }

bool operator>(const Path &rhs) const { return (path_ > rhs.path_); }

bool operator<=(const Path &rhs) const { return (path_ <= rhs.path_); }

bool operator>=(const Path &rhs) const { return (path_ >= rhs.path_); }

bool Exists();

bool IsDirectory();


+ 12
- 1
mindspore/ccsrc/minddata/dataset/util/task.cc View File

@@ -37,6 +37,11 @@ void Task::operator()() {
ss << Services::GetUniqueID();
#endif
MS_LOG(DEBUG) << my_name_ << " Thread ID " << ss.str() << " Started.";

#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID)
native_handle_ = pthread_self();
#endif

try {
// Previously there is a timing hole where the thread is spawn but hit error immediately before we can set
// the TaskGroup pointer and register. We move the registration logic to here (after we spawn) so we can
@@ -96,7 +101,8 @@ Task::Task(const std::string &myName, const std::function<Status()> &f)
task_group_(nullptr),
is_master_(false),
running_(false),
caught_severe_exception_(false) {
caught_severe_exception_(false),
native_handle_(0) {
IntrpResource::ResetIntrpState();
wp_.ResetIntrpState();
wp_.Clear();
@@ -164,5 +170,10 @@ Status Task::OverrideInterruptRc(const Status &rc) {
}
return rc;
}

#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID)
pthread_t Task::GetNativeHandle() const { return native_handle_; }
#endif

} // namespace dataset
} // namespace mindspore

+ 14
- 1
mindspore/ccsrc/minddata/dataset/util/task.h View File

@@ -16,6 +16,9 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_H_

#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID)
#include <pthread.h>
#endif
#include <chrono>
#include <exception>
#include <functional>
@@ -84,7 +87,7 @@ class Task : public IntrpResource {

std::thread::id get_id() { return id_; }

std::string MyName() { return my_name_; }
std::string MyName() const { return my_name_; }

// An operator used by std::find
bool operator==(const Task &other) const { return (this == &other); }
@@ -97,6 +100,10 @@ class Task : public IntrpResource {

static Status OverrideInterruptRc(const Status &rc);

#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID)
pthread_t GetNativeHandle() const;
#endif

private:
mutable std::mutex mux_;
std::string my_name_;
@@ -113,6 +120,12 @@ class Task : public IntrpResource {
volatile bool running_;
volatile bool caught_severe_exception_;

#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID)
pthread_t native_handle_;
#else
uint64_t native_handle_;
#endif

void ShutdownGroup();
TaskGroup *MyTaskGroup();
void set_task_group(TaskGroup *vg);


+ 0
- 1
tests/ut/cpp/dataset/cache_op_test.cc View File

@@ -24,7 +24,6 @@
#include "common/common.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "minddata/dataset/util/storage_container.h" // lint !e322
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/data_schema.h"



+ 10
- 1
tests/ut/python/cachetests/cachetest_py.sh View File

@@ -31,7 +31,7 @@ HandleRcExit $? 1 1
export RUN_CACHE_TEST=TRUE

# Each of these tests will create session, use it, then destroy it after the test
for i in $(seq 1 6)
for i in $(seq 1 5)
do
test_name="test_cache_map_basic${i}"
GetSession
@@ -121,6 +121,12 @@ HandleRcExit $? 0 0
PytestCmd "test_cache_map.py" "test_cache_map_voc" 1
HandleRcExit $? 0 0

PytestCmd "test_cache_map.py" "test_cache_map_python_sampler" 1
HandleRcExit $? 0 0

PytestCmd "test_cache_map.py" "test_cache_map_nested_repeat"
HandleRcExit $? 0 0

# Run two parallel pipelines (sharing cache)
for i in $(seq 1 2)
do
@@ -309,6 +315,9 @@ HandleRcExit $? 0 0
PytestCmd "test_cache_nomap.py" "test_cache_nomap_textfile" 1
HandleRcExit $? 0 0

PytestCmd "test_cache_nomap.py" "test_cache_nomap_nested_repeat"
HandleRcExit $? 0 0

for i in $(seq 1 3)
do
test_name="test_cache_nomap_multiple_cache${i}"


+ 166
- 47
tests/ut/python/dataset/test_cache_map.py View File

@@ -107,49 +107,10 @@ def test_cache_map_basic2():

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic3():
"""
Test a repeat under mappable cache

Cache
|
Map(decode)
|
Repeat
|
ImageFolder
"""

logger.info("Test cache basic 3")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")

some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
decode_op = c_vision.Decode()
ds1 = ds1.repeat(4)
ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
logger.info("ds1.dataset_size is ", ds1.get_dataset_size())

num_iter = 0
for _ in ds1.create_dict_iterator(num_epochs=1):
logger.info("get data from dataset")
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 8
logger.info('test_cache_basic3 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic4():
"""
Test different rows result in core dump
"""
logger.info("Test cache basic 4")
logger.info("Test cache basic 3")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
@@ -171,11 +132,11 @@ def test_cache_map_basic4():

logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 8
logger.info('test_cache_basic4 Ended.\n')
logger.info('test_cache_basic3 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic5():
def test_cache_map_basic4():
"""
Test Map with non-deterministic TensorOps above cache

@@ -188,7 +149,7 @@ def test_cache_map_basic5():
ImageFolder

"""
logger.info("Test cache failure 5")
logger.info("Test cache basic 4")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
@@ -211,11 +172,11 @@ def test_cache_map_basic5():

logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 8
logger.info('test_cache_failure5 Ended.\n')
logger.info('test_cache_basic4 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic6():
def test_cache_map_basic5():
"""
Test cache as root node

@@ -223,7 +184,7 @@ def test_cache_map_basic6():
|
ImageFolder
"""
logger.info("Test cache basic 6")
logger.info("Test cache basic 5")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
@@ -239,7 +200,7 @@ def test_cache_map_basic6():

logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 2
logger.info('test_cache_basic6 Ended.\n')
logger.info('test_cache_basic5 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
@@ -502,6 +463,7 @@ def test_cache_map_failure7():
Generator

"""

def generator_1d():
for i in range(64):
yield (np.array(i),)
@@ -528,6 +490,44 @@ def test_cache_map_failure7():
logger.info('test_cache_failure7 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_failure8():
"""
Test a repeat under mappable cache (failure)

Cache
|
Map(decode)
|
Repeat
|
ImageFolder
"""

logger.info("Test cache failure 8")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")

some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
decode_op = c_vision.Decode()
ds1 = ds1.repeat(4)
ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)

with pytest.raises(RuntimeError) as e:
num_iter = 0
for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "Repeat is not supported as a descendant operator under a mappable cache" in str(e.value)

assert num_iter == 0
logger.info('test_cache_failure8 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_parameter_check():
"""
@@ -1702,6 +1702,125 @@ def test_cache_map_voc2():
logger.info("test_cache_map_voc2 Ended.\n")


class ReverseSampler(ds.Sampler):
def __iter__(self):
for i in range(self.dataset_size - 1, -1, -1):
yield i


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_python_sampler1():
"""
Test using a python sampler, and cache after leaf

Repeat
|
Map(decode)
|
cache
|
ImageFolder
"""

logger.info("Test cache map python sampler1")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")

some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler(), cache=some_cache)
decode_op = c_vision.Decode()
ds1 = ds1.map(input_columns=["image"], operations=decode_op)
ds1 = ds1.repeat(4)

num_iter = 0
for _ in ds1.create_dict_iterator():
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 8
logger.info("test_cache_map_python_sampler1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_python_sampler2():
"""
Test using a python sampler, and cache after map

Repeat
|
cache
|
Map(decode)
|
ImageFolder
"""

logger.info("Test cache map python sampler2")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")

some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler())
decode_op = c_vision.Decode()
ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
ds1 = ds1.repeat(4)

num_iter = 0
for _ in ds1.create_dict_iterator():
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 8
logger.info("test_cache_map_python_sampler2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_nested_repeat():
"""
Test cache on pipeline with nested repeat ops

Repeat
|
Map(decode)
|
Repeat
|
Cache
|
ImageFolder
"""

logger.info("Test cache map nested repeat")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")

some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
decode_op = c_vision.Decode()
ds1 = ds1.repeat(4)
ds1 = ds1.map(operations=decode_op, input_columns=["image"])
ds1 = ds1.repeat(2)

num_iter = 0
for _ in ds1.create_dict_iterator(num_epochs=1):
logger.info("get data from dataset")
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 16
logger.info('test_cache_map_nested_repeat Ended.\n')


if __name__ == '__main__':
test_cache_map_basic1()
test_cache_map_basic2()


+ 85
- 0
tests/ut/python/dataset/test_cache_nomap.py View File

@@ -1292,6 +1292,50 @@ def test_cache_nomap_epoch_ctrl3():
logger.info("test_cache_nomap_epoch_ctrl3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_epoch_ctrl4():
"""
Test using two-loops method with repeat under cache

cache
|
Map(decode)
|
repeat
|
TFRecord
"""

logger.info("Test cache nomap epoch ctrl4")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")

some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
ds1 = ds1.repeat(2)
decode_op = c_vision.Decode()
ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)

num_epoch = 5
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

epoch_count = 0
for _ in range(num_epoch):
row_count = 0
for _ in iter1:
row_count += 1
logger.info("Number of data in ds1: {} ".format(row_count))
assert row_count == 6
epoch_count += 1
assert epoch_count == num_epoch

logger.info("test_cache_nomap_epoch_ctrl4 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache1():
"""
@@ -1734,6 +1778,47 @@ def test_cache_nomap_textfile2():
logger.info("test_cache_nomap_textfile2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_nested_repeat():
"""
Test cache on pipeline with nested repeat ops

Repeat
|
Cache
|
Map(decode)
|
Repeat
|
TFRecord
"""

logger.info("Test cache nomap nested repeat")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")

some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
decode_op = c_vision.Decode()
ds1 = ds1.repeat(4)
ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
ds1 = ds1.repeat(2)

num_iter = 0
for _ in ds1.create_dict_iterator(num_epochs=1):
logger.info("get data from dataset")
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 24
logger.info('test_cache_nomap_nested_repeat Ended.\n')


if __name__ == '__main__':
test_cache_nomap_basic1()
test_cache_nomap_basic2()


+ 10
- 0
tests/ut/python/test_server_stop_testcase.sh View File

@@ -0,0 +1,10 @@
~/cache/cache_admin --start
session_id=$(~/cache/cache_admin -g | awk '{print $NF}')
export SESSION_ID=${session_id}
pytest dataset/test_cache_nomap.py::test_cache_nomap_server_stop &
pid=("$!")

sleep 2
~/cache/cache_admin --stop
sleep 1
wait ${pid}

Loading…
Cancel
Save