Merge pull request !2891 from Jamie/CacheOp_dev (tags/v0.6.0-beta)
| @@ -47,6 +47,8 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/dataset/include) | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default") | |||
| ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${CMAKE_BINARY_DIR}) | |||
| ################## Include sub-modules ############################### | |||
| add_subdirectory(util) | |||
| add_subdirectory(core) | |||
| @@ -55,7 +57,7 @@ add_subdirectory(engine) | |||
| add_subdirectory(api) | |||
| add_subdirectory(text) | |||
| ###################################################################### | |||
| add_dependencies(core utils) | |||
| add_dependencies(utils core) | |||
| add_dependencies(kernels-image core) | |||
| add_dependencies(kernels-data core) | |||
| add_dependencies(kernels core) | |||
| @@ -89,6 +91,8 @@ set(submodules | |||
| $<TARGET_OBJECTS:engine-perf> | |||
| $<TARGET_OBJECTS:engine-datasetops> | |||
| $<TARGET_OBJECTS:engine-opt> | |||
| $<TARGET_OBJECTS:engine-cache-client> | |||
| $<TARGET_OBJECTS:engine-cache-server> | |||
| $<TARGET_OBJECTS:engine> | |||
| $<TARGET_OBJECTS:text> | |||
| $<TARGET_OBJECTS:text-kernels> | |||
| @@ -106,6 +110,8 @@ else () | |||
| add_library(_c_dataengine SHARED ${submodules}) | |||
| endif () | |||
| add_dependencies(_c_dataengine generated_engine_files) | |||
| set_target_properties(_c_dataengine PROPERTIES | |||
| PREFIX "${PYTHON_MODULE_PREFIX}" | |||
| SUFFIX "${PYTHON_MODULE_EXTENSION}" | |||
| @@ -21,8 +21,10 @@ | |||
| #include "common/utils.h" | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/engine/cache/cache_client.h" | |||
| #include "dataset/engine/dataset_iterator.h" | |||
| #include "dataset/engine/datasetops/bucket_batch_by_length_op.h" | |||
| #include "dataset/engine/datasetops/cache_op.h" | |||
| #include "dataset/engine/datasetops/filter_op.h" | |||
| #include "dataset/engine/datasetops/source/celeba_op.h" | |||
| #include "dataset/engine/datasetops/source/cifar_op.h" | |||
| @@ -34,6 +36,7 @@ | |||
| #include "dataset/engine/datasetops/source/random_data_op.h" | |||
| #include "dataset/engine/datasetops/source/text_file_op.h" | |||
| #include "dataset/engine/datasetops/source/voc_op.h" | |||
| #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "dataset/kernels/py_func_op.h" | |||
| #include "dataset/util/random.h" | |||
| #include "dataset/util/status.h" | |||
| @@ -441,6 +444,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> * | |||
| MapOp::Builder map_builder; | |||
| std::vector<std::shared_ptr<TensorOp>> tensor_op_list; | |||
| std::vector<std::string> project_columns; | |||
| std::shared_ptr<CacheClient> cache_client = nullptr; | |||
| int num_workers = 0; | |||
| if (args["operations"].is_none()) RETURN_STATUS_UNEXPECTED("Error: 'operations' is not set. \n"); | |||
| @@ -456,7 +461,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> * | |||
| } else if (key == "columns_order") { | |||
| project_columns = ToStringVector(value); | |||
| } else if (key == "num_parallel_workers") { | |||
| (void)map_builder.SetNumWorkers(ToInt(value)); | |||
| num_workers = ToInt(value); | |||
| (void)map_builder.SetNumWorkers(num_workers); | |||
| } else if (key == "prefetch_size") { | |||
| (void)map_builder.SetOpConnectorSize(ToInt(value)); | |||
| } else if (key == "operations") { | |||
| @@ -477,6 +483,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> * | |||
| } | |||
| if (tensor_op_list.empty()) RETURN_STATUS_UNEXPECTED("Error: tensor_op is invalid or not set."); | |||
| (void)map_builder.SetTensorFuncs(std::move(tensor_op_list)); | |||
| } else if (key == "cache") { | |||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); | |||
| } | |||
| @@ -499,6 +507,15 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> * | |||
| *bottom = map_op; | |||
| } | |||
| // Additionally, add a cache if required. This will go over top of the project op if one | |||
| // was created, otherwise it goes over top of the map op | |||
| if (cache_client) { | |||
| std::shared_ptr<DatasetOp> cache_op = nullptr; | |||
| RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, *top, &cache_op)); | |||
| *top = cache_op; | |||
| *bottom = map_op; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -809,6 +826,9 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset | |||
| std::shared_ptr<DatasetOp> *bottom) { | |||
| // Required arguments | |||
| std::vector<std::string> files_list; | |||
| std::shared_ptr<CacheClient> cache_client = nullptr; | |||
| std::shared_ptr<Sampler> sampler = nullptr; | |||
| int num_workers = 0; | |||
| std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>(); | |||
| if (!args["dataset_files"].is_none()) { | |||
| files_list = ToStringVector(args["dataset_files"]); | |||
| @@ -828,7 +848,8 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset | |||
| py::handle value = arg.second; | |||
| if (!value.is_none()) { | |||
| if (key == "num_parallel_workers") { | |||
| (void)builder->SetNumWorkers(ToInt(value)); | |||
| num_workers = ToInt(value); | |||
| (void)builder->SetNumWorkers(num_workers); | |||
| } else if (key == "columns_list") { | |||
| columns_to_load = ToStringVector(value); | |||
| (void)builder->SetColumnsToLoad(columns_to_load); | |||
| @@ -848,6 +869,11 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset | |||
| (void)builder->SetDeviceId(ToInt(value)); | |||
| } else if (key == "shard_equal_rows") { | |||
| (void)builder->SetShardEqualRows(ToBool(value)); | |||
| } else if (key == "cache") { | |||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | |||
| } else if (key == "sampler") { | |||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | |||
| sampler = create().cast<std::shared_ptr<Sampler>>(); | |||
| } | |||
| } | |||
| } | |||
| @@ -860,12 +886,27 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset | |||
| } | |||
| (void)builder->SetDataSchema(std::move(schema)); | |||
| } | |||
| // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed | |||
| // because TFReaderOp is a non-mappable dataset that does not support sampling. | |||
| // However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| // inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| // That is why we save the sampler here in a leaf node that does not use sampling. | |||
| if (sampler) { | |||
| (void)builder->SetSampler(std::move(sampler)); | |||
| } else if (cache_client) { | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| (void)builder->SetSampler(std::move(sampler)); | |||
| } | |||
| std::shared_ptr<TFReaderOp> tf_op; | |||
| RETURN_IF_NOT_OK(builder->Build(&tf_op)); | |||
| RETURN_IF_NOT_OK(tree_->AssociateNode(tf_op)); | |||
| *top = tf_op; | |||
| if (shuffle_required) { | |||
| if (!cache_client && shuffle_required) { | |||
| const bool estimate = true; | |||
| const int64_t workers = 8; | |||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | |||
| @@ -882,6 +923,15 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset | |||
| *bottom = tf_op; | |||
| } | |||
| // Add a cache op over this op if required and update the output subtree (top/bottom) | |||
| if (cache_client) { | |||
| // Note, it is not allowed to have both shuffle and cache | |||
| std::shared_ptr<DatasetOp> cache_op = nullptr; | |||
| RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, tf_op, &cache_op)); | |||
| *top = cache_op; | |||
| *bottom = tf_op; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
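The interplay of `sampler`, `cache`, and `shuffle_required` above is easy to misread, so here is a toy decision table (hypothetical helper, behavior inferred from this hunk) showing what the non-mappable leaf ends up with:

```cpp
#include <iostream>
#include <string>

// Toy decision table (illustrative only) for the cached non-mappable leaf logic above:
// which sampler the leaf keeps, and whether a ShuffleOp or CacheOp is injected over it.
std::string LeafPlan(bool has_sampler, bool has_cache, bool shuffle_required) {
  // A user sampler is stashed on the leaf; with a cache but no sampler, a default
  // SequentialSampler(0, 0) is stashed instead.
  std::string plan = has_sampler ? "user sampler" : (has_cache ? "SequentialSampler" : "no sampler");
  if (shuffle_required && !has_cache) plan += " + ShuffleOp";  // shuffle and cache are exclusive
  if (has_cache) plan += " + CacheOp";
  return plan;
}

int main() {
  std::cout << LeafPlan(false, true, true) << "\n";   // SequentialSampler + CacheOp
  std::cout << LeafPlan(false, false, true) << "\n";  // no sampler + ShuffleOp
  std::cout << LeafPlan(true, true, false) << "\n";   // user sampler + CacheOp
  return 0;
}
```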
| @@ -906,6 +956,8 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data | |||
| std::string err_msg = "Error: No dataset path specified"; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| int num_workers = 0; | |||
| std::shared_ptr<CacheClient> cache_client = nullptr; | |||
| std::shared_ptr<ImageFolderOp::Builder> builder = std::make_shared<ImageFolderOp::Builder>(); | |||
| (void)builder->SetImageFolderDir(ToString(args["dataset_dir"])); | |||
| @@ -915,7 +967,8 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data | |||
| py::handle value = arg.second; | |||
| if (!value.is_none()) { | |||
| if (key == "num_parallel_workers") { | |||
| (void)builder->SetNumWorkers(ToInt(value)); | |||
| num_workers = ToInt(value); | |||
| (void)builder->SetNumWorkers(num_workers); | |||
| } else if (key == "sampler") { | |||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | |||
| std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | |||
| @@ -926,12 +979,27 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data | |||
| (void)builder->SetClassIndex(ToStringMap(value)); | |||
| } else if (key == "decode") { | |||
| (void)builder->SetDecode(ToBool(value)); | |||
| } else if (key == "cache") { | |||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | |||
| } | |||
| } | |||
| } | |||
| std::shared_ptr<ImageFolderOp> op; | |||
| RETURN_IF_NOT_OK(builder->Build(&op)); | |||
| *top = op; | |||
| std::shared_ptr<ImageFolderOp> if_op; | |||
| RETURN_IF_NOT_OK(builder->Build(&if_op)); | |||
| RETURN_IF_NOT_OK(tree_->AssociateNode(if_op)); | |||
| *top = if_op; | |||
| // Additionally, add a cache if required. | |||
| // Note that this cache op is only acting as a placeholder for the caching position | |||
| // within the tree. Later, a pre-pass will execute a tree transform to set up the actual | |||
| // caching logic in the tree. | |||
| if (cache_client) { | |||
| std::shared_ptr<DatasetOp> cache_op = nullptr; | |||
| RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, if_op, &cache_op)); | |||
| *top = cache_op; | |||
| *bottom = if_op; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -1130,9 +1198,12 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas | |||
| std::shared_ptr<DatasetOp> *bottom) { | |||
| // Required arguments | |||
| RandomDataOp::Builder builder; | |||
| std::shared_ptr<CacheClient> cache_client = nullptr; | |||
| std::shared_ptr<Sampler> sampler = nullptr; | |||
| int num_workers = 0; | |||
| if (args["num_samples"].is_none()) { | |||
| std::string err_msg = "Error: num_samples is a required argument"; | |||
| if (args["total_rows"].is_none()) { | |||
| std::string err_msg = "Error: total_rows is a required argument"; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| std::vector<std::string> columns_to_load; | |||
| @@ -1141,16 +1212,23 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas | |||
| for (auto arg : args) { | |||
| std::string key = py::str(arg.first); | |||
| py::handle value = arg.second; | |||
| if (key == "num_parallel_workers") { | |||
| (void)builder.SetNumWorkers(ToInt(value)); | |||
| } else if (key == "schema_file_path" || key == "schema_json_string") { | |||
| schema_exists = true; | |||
| } else if (key == "columns_list") { | |||
| columns_to_load = ToStringVector(value); | |||
| } else if (key == "num_samples") { | |||
| // This is not sampling here. The random data op needs to know how much data to | |||
| // generate. It does not currently support sampling. | |||
| (void)builder.SetTotalRows(ToInt(value)); | |||
| if (!value.is_none()) { | |||
| if (key == "num_parallel_workers") { | |||
| num_workers = ToInt(value); | |||
| (void)builder.SetNumWorkers(num_workers); | |||
| } else if (key == "schema_file_path" || key == "schema_json_string") { | |||
| schema_exists = true; | |||
| } else if (key == "columns_list") { | |||
| columns_to_load = ToStringVector(value); | |||
| } else if (key == "total_rows") { | |||
| // This is not sampling here. The random data op needs to know how much data to generate. | |||
| (void)builder.SetTotalRows(ToInt(value)); | |||
| } else if (key == "cache") { | |||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | |||
| } else if (key == "sampler") { | |||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | |||
| sampler = create().cast<std::shared_ptr<Sampler>>(); | |||
| } | |||
| } | |||
| } | |||
| if (schema_exists) { | |||
| @@ -1162,9 +1240,34 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas | |||
| } | |||
| (void)builder.SetDataSchema(std::move(schema)); | |||
| } | |||
| std::shared_ptr<RandomDataOp> op; | |||
| RETURN_IF_NOT_OK(builder.Build(&op)); | |||
| *top = op; | |||
| // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed | |||
| // because RandomDataOp is a non-mappable dataset that does not support sampling. | |||
| // However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| // inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| // That is why we save the sampler here in a leaf node that does not use sampling. | |||
| if (sampler) { | |||
| (void)builder.SetSampler(std::move(sampler)); | |||
| } else if (cache_client) { | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| (void)builder.SetSampler(std::move(sampler)); | |||
| } | |||
| std::shared_ptr<RandomDataOp> random_op = nullptr; | |||
| RETURN_IF_NOT_OK(builder.Build(&random_op)); | |||
| RETURN_IF_NOT_OK(tree_->AssociateNode(random_op)); | |||
| *top = random_op; | |||
| // Add a cache op over this op if required and update the output subtree (top/bottom) | |||
| if (cache_client) { | |||
| std::shared_ptr<DatasetOp> cache_op = nullptr; | |||
| RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, random_op, &cache_op)); | |||
| *top = cache_op; | |||
| *bottom = random_op; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -1425,6 +1528,31 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||
| return Status::OK(); | |||
| } | |||
| // Helper function to inject the cache operator over top of the current operation being built. | |||
| Status DEPipeline::AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers, | |||
| std::shared_ptr<DatasetOp> input_op, std::shared_ptr<DatasetOp> *cache_op) { | |||
| std::shared_ptr<CacheOp> new_cache_op = nullptr; | |||
| CacheOp::Builder cache_builder; | |||
| // Use the same number of workers as the leaf. This needs some optimization later, since the user does not | |||
| // give the cache op a number of workers directly. | |||
| if (num_workers != 0) { | |||
| (void)cache_builder.SetNumWorkers(num_workers); | |||
| } | |||
| (void)cache_builder.SetClient(cache_client); | |||
| RETURN_IF_NOT_OK(cache_builder.Build(&new_cache_op)); | |||
| RETURN_IF_NOT_OK(tree_->AssociateNode(new_cache_op)); | |||
| RETURN_IF_NOT_OK(new_cache_op->AddChild(input_op)); | |||
| // We have now created: | |||
| // | |||
| // CacheOp | |||
| // | | |||
| // input_op | |||
| // | |||
| *cache_op = new_cache_op; | |||
| return Status::OK(); | |||
| } | |||
| // Helper function to inject a shuffle operator over top of the current operation being built. | |||
| Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op, | |||
| std::shared_ptr<DatasetOp> *shuffle_op) { | |||
| @@ -35,6 +35,8 @@ namespace mindspore { | |||
| namespace dataset { | |||
| using DsOpPtr = std::shared_ptr<DatasetOp>; | |||
| class CacheClient; | |||
| // enum for the dataset operator names | |||
| enum OpName { | |||
| kShuffle, | |||
| @@ -181,6 +183,16 @@ class DEPipeline { | |||
| static Status ParsePadInfo(py::handle value, PadInfo *pad_info); | |||
| /// \brief Helper function to inject a cache operator over top of the current operation being built. | |||
| /// \param[in] cache_client The client to use for caching | |||
| /// \param[in] num_workers The number of workers to use in the cache op | |||
| /// \param[in] input_op The operator to build the cache on top of | |||
| /// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be | |||
| /// the cache operator | |||
| /// \return Status return code | |||
| Status AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers, std::shared_ptr<DatasetOp> input_op, | |||
| std::shared_ptr<DatasetOp> *cache_op); | |||
| /// \brief Helper function to inject a shuffle operator over top of the current operation being built. | |||
| /// \param[in] shuffle_size The size to use in the shuffle buffer | |||
| /// \param[in] input_op The operator to build shuffle on top of | |||
| @@ -35,6 +35,7 @@ | |||
| #include "dataset/engine/datasetops/source/text_file_op.h" | |||
| #include "dataset/engine/datasetops/source/tf_reader_op.h" | |||
| #include "dataset/engine/datasetops/source/voc_op.h" | |||
| #include "dataset/engine/cache/cache_client.h" | |||
| #include "dataset/engine/gnn/graph.h" | |||
| #include "dataset/engine/jagged_connector.h" | |||
| #include "dataset/kernels/data/concatenate_op.h" | |||
| @@ -768,6 +769,11 @@ void bindInfoObjects(py::module *m) { | |||
| .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num); | |||
| } | |||
| void bindCacheClient(py::module *m) { | |||
| (void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient") | |||
| .def(py::init<uint32_t, uint64_t, bool>()); | |||
| } | |||
| void bindVocabObjects(py::module *m) { | |||
| (void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab") | |||
| .def(py::init<>()) | |||
| @@ -939,6 +945,7 @@ PYBIND11_MODULE(_c_dataengine, m) { | |||
| bindSamplerOps(&m); | |||
| bindDatasetOps(&m); | |||
| bindInfoObjects(&m); | |||
| bindCacheClient(&m); | |||
| bindVocabObjects(&m); | |||
| bindGraphData(&m); | |||
| bindDependIcuTokenizerOps(&m); | |||
| @@ -2,6 +2,7 @@ add_subdirectory(datasetops) | |||
| add_subdirectory(opt) | |||
| add_subdirectory(gnn) | |||
| add_subdirectory(perf) | |||
| add_subdirectory(cache) | |||
| if (ENABLE_TDTQUE) | |||
| add_subdirectory(tdt) | |||
| endif () | |||
| @@ -17,7 +18,9 @@ add_library(engine OBJECT | |||
| target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) | |||
| if (ENABLE_TDTQUE) | |||
| add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf) | |||
| else() | |||
| add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf) | |||
| add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf | |||
| engine-cache-client engine-cache-server) | |||
| else () | |||
| add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf | |||
| engine-cache-client engine-cache-server) | |||
| endif () | |||
| @@ -0,0 +1,8 @@ | |||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| add_library(engine-cache-client OBJECT | |||
| cache_client.cc | |||
| cache_request.cc) | |||
| add_library(engine-cache-server OBJECT | |||
| cache_service.cc | |||
| cache_server.cc) | |||
| @@ -0,0 +1,208 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iomanip> | |||
| #include "dataset/engine/cache/cache_client.h" | |||
| #include "dataset/engine/cache/cache_request.h" | |||
| #include "dataset/util/bit.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor | |||
| CacheClient::CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill) | |||
| : server_connection_id_(0), session_id_(session_id), cache_crc_(0), cache_mem_sz_(cache_mem_sz), spill_(spill) {} | |||
| // Print method for displaying cache details | |||
| void CacheClient::Print(std::ostream &out) const { | |||
| out << " Session id: " << session_id_ << "\n Cache crc: " << cache_crc_ | |||
| << "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << cache_mem_sz_ | |||
| << "\n Spilling: " << std::boolalpha << spill_; | |||
| } | |||
| Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const { | |||
| CacheRowRequest rq(server_connection_id_, cookie()); | |||
| RETURN_IF_NOT_OK(rq.SerializeCacheRowRequest(row)); | |||
| RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); | |||
| RETURN_IF_NOT_OK(rq.Wait()); | |||
| if (row_id_from_server != nullptr) { | |||
| *row_id_from_server = rq.GetRowIdAfterCache(); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const { | |||
| std::unique_ptr<DataBuffer> db_ptr = std::move(in); | |||
| auto num_rows = db_ptr->NumRows(); | |||
| std::vector<TensorRow> all_rows; | |||
| if (num_rows > 0) { | |||
| all_rows.reserve(num_rows); | |||
| // Break down the DataBuffer into TensorRows. We will send the requests async | |||
| // and then do a final wait. | |||
| MemGuard<CacheRowRequest> rq_arr; | |||
| RETURN_IF_NOT_OK(rq_arr.allocate(num_rows, server_connection_id_, cookie())); | |||
| CacheServer &cs = CacheServer::GetInstance(); | |||
| for (auto i = 0; i < num_rows; ++i) { | |||
| TensorRow row; | |||
| auto rq = rq_arr[i]; | |||
| RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); | |||
| RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(row)); | |||
| RETURN_IF_NOT_OK(cs.PushRequest(rq)); | |||
| // We can't let row go out of scope. Otherwise it will free all the tensor memory. | |||
| // So park it in the vector. When this function returns, that memory | |||
| // will be freed. | |||
| all_rows.push_back(std::move(row)); | |||
| } | |||
| // Now we wait for the requests to be done. | |||
| for (auto i = 0; i < num_rows; ++i) { | |||
| auto rq = rq_arr[i]; | |||
| RETURN_IF_NOT_OK(rq->Wait()); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
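`WriteBuffer` above pushes every row request asynchronously and only waits once at the end, parking the rows so their tensor memory stays alive until the waits complete. A minimal standard-library sketch of that fan-out/fan-in shape (the real code goes through the cache server queue and `MemGuard`):

```cpp
#include <future>
#include <iostream>
#include <vector>

int main() {
  // Fan-out: issue all "requests" without waiting (analogous to cs.PushRequest(rq)).
  std::vector<std::future<int>> pending;
  for (int i = 0; i < 4; ++i) {
    pending.push_back(std::async(std::launch::async, [i] { return i * i; }));
  }
  // Fan-in: a single final wait pass (analogous to the rq->Wait() loop).
  int total = 0;
  for (auto &f : pending) total += f.get();
  std::cout << "total=" << total << "\n";  // prints total=14
  return 0;
}
```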
| Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| BatchFetchRequest rq(server_connection_id_, row_id); | |||
| RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); | |||
| RETURN_IF_NOT_OK(rq.Wait()); | |||
| RETURN_IF_NOT_OK(rq.RestoreRows(out)); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) { | |||
| UniqueLock lck(&mux_); | |||
| // To create a cache, we identify ourselves at the client by: | |||
| // - the shared session id | |||
| // - a crc for the tree nodes from the cache downward | |||
| // Pack these 2 into a single 64 bit request id | |||
| // | |||
| // Consider this example: | |||
| // tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch | |||
| // tree2: cifar10 --> map(rotate) --> cache (session id = 1, crc = 456) --> batch | |||
| // These are different trees in a single session, but the user wants to share the cache. | |||
| // This is not allowed because the data of these caches are different. | |||
| // | |||
| // Consider this example: | |||
| // tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch | |||
| // tree2: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> map(rotate) --> batch | |||
| // These are different trees in the same session, but the cached data is the same, so it is okay | |||
| // to allow the sharing of this cache between these pipelines. | |||
| // The CRC is computed by the tree prepare phase and passed to this function when creating the cache. | |||
| // If we already have a server_connection_id_, then it means this same cache client has already been used | |||
| // to create a cache and some other tree is trying to use the same cache. | |||
| // That is allowed, however the crc better match! | |||
| if (server_connection_id_) { | |||
| if (cache_crc_ != tree_crc) { | |||
| RETURN_STATUS_UNEXPECTED("Attempt to re-use a cache for a different tree!"); | |||
| } | |||
| // Check the state of the server. For the non-mappable case, where there is a build phase and a fetch phase, we | |||
| // should skip the build phase. | |||
| lck.Unlock(); // GetStat will grab the mutex again. So unlock it to prevent deadlock. | |||
| CacheClient::ServiceStat stat{}; | |||
| RETURN_IF_NOT_OK(GetStat(&stat)); | |||
| if (stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase)) { | |||
| return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase"); | |||
| } | |||
| } else { | |||
| cache_crc_ = tree_crc; // It's really a new cache we're creating so save our crc in the client | |||
| // Combine the session and crc. This will form our client cache identifier. | |||
| connection_id_type connection_identification = (static_cast<uint64_t>(session_id_) << 32) | cache_crc_; | |||
| // Now execute the cache create request using this identifier and other configs | |||
| BaseRequest::CreateCacheFlag createFlag = BaseRequest::CreateCacheFlag::kNone; | |||
| if (spill_) { | |||
| createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk; | |||
| } | |||
| if (generate_id) { | |||
| createFlag |= BaseRequest::CreateCacheFlag::kGenerateRowId; | |||
| } | |||
| CreationCacheRequest rq(connection_identification, cache_mem_sz_, createFlag); | |||
| RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); | |||
| Status rc = rq.Wait(); | |||
| if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) { | |||
| server_connection_id_ = rq.GetServerConnectionId(); | |||
| if (rc.IsOk()) { | |||
| // The 1st guy creating the cache will get a cookie back. | |||
| // But this object may be shared among pipelines and we don't want to | |||
| // overwrite it. | |||
| cookie_ = rq.cookie(); | |||
| } | |||
| } | |||
| // We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the | |||
| // CacheOp to bypass the build phase. | |||
| return rc; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
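To make the comment block above concrete, here is a small runnable toy (hypothetical names) of the identity logic: the connection id packs the session id into the high 32 bits and the tree crc into the low 32 bits, and a client already bound to a crc rejects a different tree:

```cpp
#include <cstdint>
#include <iostream>
#include <optional>

// Toy version of CacheClient::CreateCache's identity checks, not the real class.
struct ToyClient {
  uint32_t session_id;
  uint32_t crc = 0;
  std::optional<uint64_t> connection_id;
  bool CreateCache(uint32_t tree_crc) {
    if (connection_id.has_value()) return crc == tree_crc;  // re-use must match the crc
    crc = tree_crc;
    connection_id = (static_cast<uint64_t>(session_id) << 32) | crc;  // same packing as above
    return true;
  }
};

int main() {
  ToyClient c{1};
  std::cout << c.CreateCache(123);  // 1: first create
  std::cout << c.CreateCache(123);  // 1: same tree crc, sharing allowed
  std::cout << c.CreateCache(456);  // 0: different tree, rejected
  std::cout << "\nid=0x" << std::hex << *c.connection_id << "\n";  // id=0x10000007b
  return 0;
}
```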
| Status CacheClient::PurgeCache() { | |||
| UniqueLock lck(&mux_); | |||
| PurgeCacheRequest rq(server_connection_id_); | |||
| RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); | |||
| return rq.Wait(); | |||
| } | |||
| Status CacheClient::DestroyCache() { | |||
| UniqueLock lck(&mux_); | |||
| DestroyCacheRequest rq(server_connection_id_); | |||
| RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); | |||
| return rq.Wait(); | |||
| } | |||
| Status CacheClient::GetStat(ServiceStat *stat) { | |||
| SharedLock lck(&mux_); | |||
| RETURN_UNEXPECTED_IF_NULL(stat); | |||
| GetStatRequest rq(server_connection_id_); | |||
| RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); | |||
| RETURN_IF_NOT_OK(rq.Wait()); | |||
| stat->num_disk_cached = rq.GetNumDiskCached(); | |||
| stat->num_mem_cached = rq.GetNumMemCached(); | |||
| stat->min_row_id = rq.GetMinRowId(); | |||
| stat->max_row_id = rq.GetMaxRowId(); | |||
| stat->cache_service_state = rq.GetState(); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) { | |||
| SharedLock lck(&mux_); | |||
| CacheSchemaRequest rq(server_connection_id_); | |||
| RETURN_IF_NOT_OK(rq.SerializeCacheSchemaRequest(map)); | |||
| RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); | |||
| RETURN_IF_NOT_OK(rq.Wait()); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheClient::FetchSchema(std::unordered_map<std::string, int32_t> *map) { | |||
| SharedLock lck(&mux_); | |||
| RETURN_UNEXPECTED_IF_NULL(map); | |||
| FetchSchemaRequest rq(server_connection_id_); | |||
| RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); | |||
| RETURN_IF_NOT_OK(rq.Wait()); | |||
| *map = rq.GetColumnMap(); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheClient::BuildPhaseDone() const { | |||
| SharedLock lck(&mux_); | |||
| BuildPhaseDoneRequest rq(server_connection_id_, cookie()); | |||
| RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); | |||
| RETURN_IF_NOT_OK(rq.Wait()); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,141 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_CACHE_CLIENT_H_ | |||
| #define DATASET_ENGINE_CACHE_CLIENT_H_ | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "./de_tensor_generated.h" | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/cache/cache_server.h" | |||
| #include "dataset/util/lock.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief A CacheClient is a bridge between a DatasetOp and a CacheServer. All communications are through | |||
| /// a CacheClient. Typical tasks include creating a cache service, caching a data buffer, and restoring | |||
| /// previously cached rows. | |||
| class CacheClient { | |||
| public: | |||
| /// \brief Constructor | |||
| /// \param session_id A user assigned session id for the current pipeline | |||
| /// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited | |||
| /// \param spill Spill to disk if out of memory | |||
| CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill); | |||
| /// \brief Destructor | |||
| ~CacheClient() = default; | |||
| /// \brief Getter function for returning the current session id | |||
| /// \return session id | |||
| uint64_t session_id() const { return session_id_; } | |||
| /// \brief Send a TensorRow to the cache server | |||
| /// \param[in] row | |||
| /// \param[out] row_id_from_server Optional. The row id assigned by the server for non-mappable dataset | |||
| /// \return return code | |||
| Status WriteRow(const TensorRow &row, row_id_type *row_id_from_server = nullptr) const; | |||
| /// \brief Send a DataBuffer to the cache server | |||
| /// \param in Unique pointer of the DataBuffer to be cached | |||
| /// \return return code | |||
| Status WriteBuffer(std::unique_ptr<DataBuffer> &&in) const; | |||
| /// \brief Fetch a list of rows from the cache server. An empty TensorRow will be returned if there is | |||
| /// any cache miss | |||
| /// \param row_id A vector of row id's | |||
| /// \param out A TensorTable of TensorRows. | |||
| /// \return return code | |||
| Status GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const; | |||
| /// \brief Create a cache. | |||
| /// \param tree_crc A crc that was generated during tree prepare phase | |||
| /// \param generate_id Let the cache service generate row id | |||
| /// \return Status object | |||
| Status CreateCache(uint32_t tree_crc, bool generate_id); | |||
| /// \brief Purge a cache. Cache can be reused after reset. | |||
| /// \return Status object | |||
| Status PurgeCache(); | |||
| /// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused. | |||
| /// \return Status object | |||
| Status DestroyCache(); | |||
| struct ServiceStat { | |||
| int64_t num_mem_cached; | |||
| int64_t num_disk_cached; | |||
| row_id_type min_row_id; | |||
| row_id_type max_row_id; | |||
| int8_t cache_service_state; | |||
| }; | |||
| /// \brief Get the statistics from a cache. | |||
| /// \param[in/out] stat Pointer to a pre-allocated ServiceStat object | |||
| /// \return Status object | |||
| Status GetStat(ServiceStat *stat); | |||
| /// \brief Cache the schema at the cache server | |||
| /// \param map The unordered map of the schema | |||
| /// \return Status object | |||
| Status CacheSchema(const std::unordered_map<std::string, int32_t> &map); | |||
| /// \brief Fetch the schema from the cache server | |||
| /// \param map Pointer to pre-allocated map object | |||
| /// \return Status object. | |||
| Status FetchSchema(std::unordered_map<std::string, int32_t> *map); | |||
| /// \brief Change the state from build phase to read phase. Applicable to non-mappable dataset only. Only the cache | |||
| /// client that holds the cookie is allowed to make this request | |||
| /// \return Status object | |||
| Status BuildPhaseDone() const; | |||
| /// \brief A print method typically used for debugging | |||
| /// \param out The output stream to write output to | |||
| void Print(std::ostream &out) const; | |||
| /// \brief Stream output operator overload | |||
| /// \return the output stream must be returned | |||
| friend std::ostream &operator<<(std::ostream &out, const CacheClient &cc) { | |||
| cc.Print(out); | |||
| return out; | |||
| } | |||
| /// \brief Every cache server has a cookie which uniquely identifies the CacheClient that created it. | |||
| /// \return Cookie | |||
| std::string cookie() const { return cookie_; } | |||
| private: | |||
| mutable RWLock mux_; | |||
| uint64_t cache_mem_sz_; | |||
| bool spill_; | |||
| // The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow | |||
| // sharing of the cache. | |||
| uint32_t session_id_; | |||
| uint32_t cache_crc_; | |||
| // The server_connection_id_ is the actual id we use for operations after the cache is built | |||
| connection_id_type server_connection_id_; | |||
| // Some magic cookie returned from the cache server. | |||
| std::string cookie_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_CACHE_CLIENT_H_ | |||
| @@ -0,0 +1,223 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/engine/cache/cache_request.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status CacheRowRequest::SerializeCacheRowRequest(const TensorRow &row) { | |||
| buffers_.reserve(row.size() + 1); | |||
| RETURN_IF_NOT_OK(SerializeTensorRowHeader(row)); | |||
| buffers_.push_back(fbb_->GetBufferPointer()); | |||
| for (const auto &ts : row) { | |||
| buffers_.push_back(ts->GetBuffer()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheRowRequest::SerializeTensorRowHeader(const TensorRow &row) { | |||
| try { | |||
| fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>(); | |||
| std::vector<flatbuffers::Offset<TensorMetaMsg>> v; | |||
| std::vector<int64_t> tensor_sz; | |||
| v.reserve(row.size()); | |||
| tensor_sz.reserve(row.size()); | |||
| // We will go through each column in the row. | |||
| for (const std::shared_ptr<Tensor> &ts_ptr : row) { | |||
| flatbuffers::Offset<TensorMetaMsg> ts_off; | |||
| RETURN_IF_NOT_OK(SerializeOneTensorMeta(ts_ptr, &ts_off)); | |||
| v.push_back(ts_off); | |||
| tensor_sz.push_back(ts_ptr->SizeInBytes()); | |||
| } | |||
| auto column_off = fbb_->CreateVector(v); | |||
| auto data_sz_off = fbb_->CreateVector(tensor_sz); | |||
| TensorRowHeaderMsgBuilder row_builder(*fbb_); | |||
| row_builder.add_column(column_off); | |||
| row_builder.add_data_sz(data_sz_off); | |||
| // Pass the row_id even if it may not be known. | |||
| row_builder.add_row_id(row.getId()); | |||
| row_builder.add_size_of_this(-1); // fill in later after we call Finish. | |||
| auto out = row_builder.Finish(); | |||
| fbb_->Finish(out); | |||
| // Now go back to fill in size_of_this in the flat buffer. | |||
| auto msg = GetMutableTensorRowHeaderMsg(fbb_->GetBufferPointer()); | |||
| auto success = msg->mutate_size_of_this(fbb_->GetSize()); | |||
| if (!success) { | |||
| RETURN_STATUS_UNEXPECTED("Unable to set size_of_this"); | |||
| } | |||
| return Status::OK(); | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||
| } | |||
| } | |||
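The `size_of_this` handling above is a write-placeholder-then-patch pattern: the flatbuffer's total size is only known after `Finish()`, so -1 is written first and mutated afterwards. A standard-library-only sketch of the same idea (this is not the flatbuffers API):

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

int main() {
  // Reserve a size slot up front and fill it with a -1 placeholder.
  std::vector<uint8_t> buf(sizeof(int32_t));
  int32_t placeholder = -1;
  std::memcpy(buf.data(), &placeholder, sizeof(placeholder));
  // Append the payload whose size is not yet known.
  const char payload[] = "tensor-row-header";
  buf.insert(buf.end(), payload, payload + sizeof(payload));
  // Now the total size is known; patch the slot (like mutate_size_of_this above).
  int32_t total = static_cast<int32_t>(buf.size());
  std::memcpy(buf.data(), &total, sizeof(total));
  std::cout << "total=" << total << "\n";  // 4-byte slot + 18-byte payload = 22
  return 0;
}
```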
| Status CacheRowRequest::SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr, | |||
| flatbuffers::Offset<TensorMetaMsg> *out_off) { | |||
| RETURN_UNEXPECTED_IF_NULL(out_off); | |||
| const Tensor *ts = ts_ptr.get(); | |||
| auto shape_off = fbb_->CreateVector(ts->shape().AsVector()); | |||
| const auto ptr = ts->GetBuffer(); | |||
| if (ptr == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("Tensor buffer is null"); | |||
| } | |||
| auto src = ts->type().value(); | |||
| TensorType dest; | |||
| #define CASE(t) \ | |||
| case DataType::t: \ | |||
| dest = TensorType::TensorType_##t; \ | |||
| break | |||
| // Map the type to fill in the flat buffer. | |||
| switch (src) { | |||
| CASE(DE_BOOL); | |||
| CASE(DE_INT8); | |||
| CASE(DE_UINT8); | |||
| CASE(DE_INT16); | |||
| CASE(DE_UINT16); | |||
| CASE(DE_INT32); | |||
| CASE(DE_UINT32); | |||
| CASE(DE_INT64); | |||
| CASE(DE_UINT64); | |||
| CASE(DE_FLOAT16); | |||
| CASE(DE_FLOAT32); | |||
| CASE(DE_FLOAT64); | |||
| CASE(DE_STRING); | |||
| default: | |||
| MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts; | |||
| RETURN_STATUS_UNEXPECTED("Unknown type"); | |||
| } | |||
| #undef CASE | |||
| TensorMetaMsgBuilder ts_builder(*fbb_); | |||
| ts_builder.add_dims(shape_off); | |||
| ts_builder.add_type(dest); | |||
| auto ts_off = ts_builder.Finish(); | |||
| *out_off = ts_off; | |||
| return Status::OK(); | |||
| } | |||
| Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, | |||
| std::shared_ptr<Tensor> *out) { | |||
| RETURN_UNEXPECTED_IF_NULL(col_ts); | |||
| auto shape_in = col_ts->dims(); | |||
| auto type_in = col_ts->type(); | |||
| std::vector<dsize_t> v; | |||
| v.reserve(shape_in->size()); | |||
| v.assign(shape_in->begin(), shape_in->end()); | |||
| TensorShape shape(v); | |||
| DataType::Type dest = DataType::DE_UNKNOWN; | |||
| #define CASE(t) \ | |||
| case TensorType_##t: \ | |||
| dest = DataType::Type::t; \ | |||
| break | |||
| switch (type_in) { | |||
| CASE(DE_BOOL); | |||
| CASE(DE_INT8); | |||
| CASE(DE_UINT8); | |||
| CASE(DE_INT16); | |||
| CASE(DE_UINT16); | |||
| CASE(DE_INT32); | |||
| CASE(DE_UINT32); | |||
| CASE(DE_INT64); | |||
| CASE(DE_UINT64); | |||
| CASE(DE_FLOAT16); | |||
| CASE(DE_FLOAT32); | |||
| CASE(DE_FLOAT64); | |||
| CASE(DE_STRING); | |||
| } | |||
| #undef CASE | |||
| DataType type(dest); | |||
| std::shared_ptr<Tensor> ts = | |||
| std::make_shared<Tensor>(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize()); | |||
| // Next we restore the real data which can be embedded or stored separately. | |||
| if (ts->SizeInBytes() != data.GetSize()) { | |||
| MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n" | |||
| << "Dumping tensor\n" | |||
| << *ts << "\n"; | |||
| RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); | |||
| } | |||
| *out = std::move(ts); | |||
| return Status::OK(); | |||
| } | |||
| Status BatchFetchRequest::RestoreRows(TensorTable *out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| auto num_elements = row_id_.size(); | |||
| auto *offset_array = reinterpret_cast<const int64_t *>(mem_.GetPointer()); | |||
| TensorTable tbl; | |||
| tbl.reserve(num_elements); | |||
| ReadableSlice all(mem_.GetPointer(), mem_.GetSizeInBytes()); | |||
| for (auto i = 0; i < num_elements; ++i) { | |||
| auto len = offset_array[i + 1] - offset_array[i]; | |||
| TensorRow row; | |||
| row.setId(row_id_.at(i)); | |||
| if (len > 0) { | |||
| ReadableSlice row_data(all, offset_array[i], len); | |||
| // Next we de-serialize flat buffer to get back each column | |||
| auto msg = GetTensorRowHeaderMsg(row_data.GetPointer()); | |||
| auto msg_sz = msg->size_of_this(); | |||
| // Start of the tensor data | |||
| auto ts_offset = msg_sz; | |||
| row.reserve(msg->column()->size()); | |||
| for (auto k = 0; k < msg->column()->size(); ++k) { | |||
| auto col_ts = msg->column()->Get(k); | |||
| std::shared_ptr<Tensor> ts; | |||
| ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k)); | |||
| RETURN_IF_NOT_OK(RestoreOneTensor(col_ts, data, &ts)); | |||
| row.push_back(ts); | |||
| ts_offset += data.GetSize(); | |||
| } | |||
| } | |||
| tbl.push_back(std::move(row)); | |||
| } | |||
| *out = std::move(tbl); | |||
| return Status::OK(); | |||
| } | |||
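For reviewers: `RestoreRows` assumes the reply buffer begins with an int64 offset table of `num_elements + 1` entries, where row `i` occupies bytes `[offset[i], offset[i+1])` and a zero length marks a cache miss. A self-contained check of that arithmetic with made-up values:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical offset table for 3 rows; the table itself is 4 * 8 = 32 bytes,
  // so the first row's data starts at offset 32. Row 0 has length 0 (cache miss).
  std::vector<int64_t> offsets = {32, 32, 160, 368};
  for (size_t i = 0; i + 1 < offsets.size(); ++i) {
    int64_t len = offsets[i + 1] - offsets[i];
    if (len > 0) {
      // Real code: slice [offsets[i], offsets[i+1]) and parse the TensorRowHeaderMsg first.
      std::cout << "row " << i << ": " << len << " bytes\n";
    } else {
      std::cout << "row " << i << ": empty TensorRow (miss)\n";
    }
  }
  return 0;
}
```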
| Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) { | |||
| try { | |||
| fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>(); | |||
| std::vector<flatbuffers::Offset<ColumnNameMsg>> v; | |||
| v.reserve(map.size()); | |||
| for (auto &column : map) { | |||
| auto c = CreateColumnNameMsg(*fbb_, fbb_->CreateString(column.first), column.second); | |||
| v.push_back(c); | |||
| } | |||
| auto v_off = fbb_->CreateVector(v); | |||
| auto final_off = CreateSchemaMsg(*fbb_, v_off); | |||
| fbb_->Finish(final_off); | |||
| buf_ = fbb_->GetBufferPointer(); | |||
| len_of_buf_ = fbb_->GetSize(); | |||
| return Status::OK(); | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||
| } | |||
| } | |||
| std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() { | |||
| if (column_name_id_map_.empty()) { | |||
| auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(mem_.GetPointer()); | |||
| auto v = map_msg->column(); | |||
| for (auto i = 0; i < v->size(); ++i) { | |||
| auto col = map_msg->column()->Get(i); | |||
| column_name_id_map_.emplace(col->name()->str(), col->id()); | |||
| } | |||
| } | |||
| return column_name_id_map_; | |||
| } | |||
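`GetColumnMap` above decodes the flatbuffer reply lazily and memoizes the result, so repeated calls don't re-walk the wire buffer. A toy illustration of that lazy-decode pattern (the decode step is faked here):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

class ToyFetch {
 public:
  const std::unordered_map<std::string, int32_t> &Get() {
    if (map_.empty()) {
      // First call: pretend to walk the SchemaMsg columns from the reply buffer.
      map_ = {{"image", 0}, {"label", 1}};
    }
    return map_;  // later calls are served from the cached copy
  }

 private:
  std::unordered_map<std::string, int32_t> map_;
};

int main() {
  ToyFetch f;
  std::cout << f.Get().size() << " columns\n";  // decodes once, prints "2 columns"
  return 0;
}
```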
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,225 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_CACHE_REQ_H_ | |||
| #define DATASET_ENGINE_CACHE_REQ_H_ | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "./de_tensor_generated.h" | |||
| #include "dataset/core/tensor_row.h" | |||
| #include "dataset/util/slice.h" | |||
| #include "dataset/util/wait_post.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief CacheClient communicates with CacheServer using Requests. | |||
| class BaseRequest { | |||
| public: | |||
| // Request types | |||
| enum class RequestType : int16_t { | |||
| kCacheRow = 0, | |||
| kBatchFetchRows = 1, | |||
| kCreateCache = 2, | |||
| kPurgeCache = 3, | |||
| kDestroyCache = 4, | |||
| kGetStat = 5, | |||
| kCacheSchema = 6, | |||
| kFetchSchema = 7, | |||
| kBuildPhaseDone = 8, | |||
| // Add new request before it. | |||
| kRequestUnknown = 32767 | |||
| }; | |||
| // For kCreateCache | |||
| enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L }; | |||
| friend class CacheServer; | |||
| /// \brief Base class of a cache server request | |||
| /// \param connection_id A combination of session id and crc that uniquely identifies a connection. | |||
| /// \param type Type of the request | |||
| explicit BaseRequest(connection_id_type connection_id, RequestType type) | |||
| : type_(type), connection_id_(connection_id) {} | |||
| virtual ~BaseRequest() = default; | |||
| /// \brief Wait for the completion of a request | |||
| /// \return Status returned from the cache server | |||
| Status Wait() { | |||
| RETURN_IF_NOT_OK(wp_.Wait()); | |||
| return rc_; | |||
| } | |||
| /// \brief Getter function of the current connection id | |||
| /// \return Connection id | |||
| connection_id_type GetServerConnectionId() const { return connection_id_; } | |||
| private: | |||
| RequestType type_; | |||
| connection_id_type connection_id_; | |||
| Status rc_; | |||
| WaitPost wp_; | |||
| }; | |||
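Since `CreateCacheFlag` is an enum class, the `|=` and `&` used in `CacheClient::CreateCache` and `CacheServer::CreateService` require explicit operator overloads, presumably supplied by `dataset/util/bit.h`. A sketch of what such helpers typically look like; the names and placement here are assumptions:

```cpp
#include <cstdint>

enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1 };

// Presumed shape of the bit helpers; the real ones live in dataset/util/bit.h.
inline CreateCacheFlag operator|(CreateCacheFlag a, CreateCacheFlag b) {
  return static_cast<CreateCacheFlag>(static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
}
inline CreateCacheFlag &operator|=(CreateCacheFlag &a, CreateCacheFlag b) { return a = a | b; }
inline CreateCacheFlag operator&(CreateCacheFlag a, CreateCacheFlag b) {
  return static_cast<CreateCacheFlag>(static_cast<uint32_t>(a) & static_cast<uint32_t>(b));
}

int main() {
  CreateCacheFlag flag = CreateCacheFlag::kNone;
  flag |= CreateCacheFlag::kSpillToDisk;  // mirrors the client's CreateCache path
  bool spill = (flag & CreateCacheFlag::kSpillToDisk) == CreateCacheFlag::kSpillToDisk;
  return spill ? 0 : 1;  // exits 0 when the flag round-trips
}
```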
| /// \brief Request to cache a single TensorRow | |||
| class CacheRowRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie) | |||
| : BaseRequest(connection_id, RequestType::kCacheRow), row_id_from_server_(-1), cookie_(cookie) {} | |||
| ~CacheRowRequest() = default; | |||
| /// \brief Serialize a TensorRow for streaming to the cache server | |||
| /// \param row TensorRow | |||
| /// \return Status object | |||
| Status SerializeCacheRowRequest(const TensorRow &row); | |||
| /// \brief Return the row id assigned to this row for non-mappable dataset | |||
| /// \return row id of the cached row | |||
| row_id_type GetRowIdAfterCache() { return row_id_from_server_; } | |||
| private: | |||
| std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_; | |||
| row_id_type row_id_from_server_; | |||
| std::vector<const void *> buffers_; | |||
| std::string cookie_; | |||
| /// \brief Private function to serialize one TensorRow | |||
| /// \param row TensorRow | |||
| /// \return Status object | |||
| Status SerializeTensorRowHeader(const TensorRow &row); | |||
| /// \brief Private function to serialize one Tensor | |||
| /// \param ts_ptr Tensor | |||
| /// \return Status object | |||
| Status SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off); | |||
| }; | |||
| /// \brief Request to fetch rows in batch | |||
| class BatchFetchRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| friend class CacheService; | |||
| BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id) | |||
| : BaseRequest(connection_id, RequestType::kBatchFetchRows), row_id_(row_id) {} | |||
| Status RestoreRows(TensorTable *out); | |||
| private: | |||
| std::vector<row_id_type> row_id_; | |||
| MemGuard<uint8_t> mem_; | |||
| Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out); | |||
| }; | |||
| /// \brief Request to create a cache for the current connection | |||
| class CreationCacheRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| /// \brief Constructor | |||
| /// \param connection_id | |||
| /// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited | |||
| /// \param flag Attributes of the cache. | |||
| explicit CreationCacheRequest(connection_id_type connection_id, uint64_t cache_mem_sz, | |||
| CreateCacheFlag flag = CreateCacheFlag::kNone) | |||
| : BaseRequest(connection_id, RequestType::kCreateCache), cache_mem_sz(cache_mem_sz), flag_(flag) {} | |||
| std::string cookie() const { return cookie_; } | |||
| private: | |||
| uint64_t cache_mem_sz; | |||
| CreateCacheFlag flag_; | |||
| std::string cookie_; | |||
| }; | |||
| /// \brief Request to purge a cache. | |||
| class PurgeCacheRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kPurgeCache) {} | |||
| }; | |||
| /// \brief Request to destroy a cache | |||
| class DestroyCacheRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| explicit DestroyCacheRequest(connection_id_type connection_id) | |||
| : BaseRequest(connection_id, RequestType::kDestroyCache) {} | |||
| }; | |||
| /// \brief Obtain the statistics of the current connection | |||
| class GetStatRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| friend class CacheService; | |||
| explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kGetStat) {} | |||
| row_id_type GetMinRowId() const { | |||
| auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer()); | |||
| return msg->min_row_id(); | |||
| } | |||
| row_id_type GetMaxRowId() const { | |||
| auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer()); | |||
| return msg->max_row_id(); | |||
| } | |||
| int64_t GetNumMemCached() const { | |||
| auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer()); | |||
| return msg->num_mem_cached(); | |||
| } | |||
| int64_t GetNumDiskCached() const { | |||
| auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer()); | |||
| return msg->num_disk_cached(); | |||
| } | |||
| uint8_t GetState() const { | |||
| auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer()); | |||
| return msg->state(); | |||
| } | |||
| private: | |||
| MemGuard<uint8_t> mem_; | |||
| }; | |||
| /// \brief Request to cache a schema | |||
| class CacheSchemaRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| explicit CacheSchemaRequest(connection_id_type connection_id) | |||
| : BaseRequest(connection_id, RequestType::kCacheSchema), buf_(nullptr), len_of_buf_(0) {} | |||
| ~CacheSchemaRequest() = default; | |||
| Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map); | |||
| const void *GetBuffer() const { return buf_; } | |||
| private: | |||
| std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_; | |||
| const void *buf_; | |||
| int64_t len_of_buf_; | |||
| }; | |||
| /// \brief Request to fetch a schema | |||
| class FetchSchemaRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| explicit FetchSchemaRequest(connection_id_type connection_id) | |||
| : BaseRequest(connection_id, RequestType::kFetchSchema) {} | |||
| ~FetchSchemaRequest() = default; | |||
| std::unordered_map<std::string, int32_t> GetColumnMap(); | |||
| private: | |||
| MemGuard<uint8_t> mem_; | |||
| std::unordered_map<std::string, int32_t> column_name_id_map_; | |||
| }; | |||
| /// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only. | |||
| class BuildPhaseDoneRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie) | |||
| : BaseRequest(connection_id, RequestType::kBuildPhaseDone), cookie_(cookie) {} | |||
| private: | |||
| std::string cookie_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_CACHE_REQ_H_ | |||
| @@ -0,0 +1,252 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/engine/cache/cache_server.h" | |||
| #include "dataset/engine/cache/cache_service.h" | |||
| #include "dataset/engine/cache/cache_request.h" | |||
| #include "dataset/util/bit.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status CacheServer::DoServiceStart() { | |||
| if (!top_.empty()) { | |||
| Path spill(top_); | |||
| RETURN_IF_NOT_OK(spill.CreateDirectories()); | |||
| MS_LOG(INFO) << "CacheServer will use disk folder: " << top_; | |||
| } | |||
| RETURN_IF_NOT_OK(vg_.ServiceStart()); | |||
| cache_q_ = std::make_shared<Queue<BaseRequest *>>(1024); | |||
| RETURN_IF_NOT_OK(cache_q_->Register(&vg_)); | |||
| auto f = std::bind(&CacheServer::ServerRequest, this); | |||
| // Spawn a few threads to serve the requests. | |||
| for (auto i = 0; i < num_workers_; ++i) { | |||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache server", f)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::DoServiceStop() { | |||
| Status rc; | |||
| Status rc2; | |||
| // First stop all the threads. | |||
| RETURN_IF_NOT_OK(vg_.ServiceStop()); | |||
| // Clean up all the caches if any. | |||
| UniqueLock lck(&rwLock_); | |||
| auto it = all_caches_.begin(); | |||
| while (it != all_caches_.end()) { | |||
| auto cs = std::move(it->second); | |||
| rc2 = cs->ServiceStop(); | |||
| if (rc2.IsError()) { | |||
| rc = rc2; | |||
| } | |||
| ++it; | |||
| } | |||
| return rc; | |||
| } | |||
| CacheService *CacheServer::GetService(connection_id_type id) const { | |||
| SharedLock lck(&rwLock_); | |||
| auto it = all_caches_.find(id); | |||
| if (it != all_caches_.end()) { | |||
| return it->second.get(); | |||
| } | |||
| return nullptr; | |||
| } | |||
| Status CacheServer::CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, | |||
| BaseRequest::CreateCacheFlag flag, std::string *out_cookie) { | |||
| // We can't do spilling unless this server is set up with a spill path in the first place | |||
| bool spill = (flag & BaseRequest::CreateCacheFlag::kSpillToDisk) == BaseRequest::CreateCacheFlag::kSpillToDisk; | |||
| bool generate_id = | |||
| (flag & BaseRequest::CreateCacheFlag::kGenerateRowId) == BaseRequest::CreateCacheFlag::kGenerateRowId; | |||
| if (spill && top_.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("Server is not set up with spill support."); | |||
| } | |||
| RETURN_UNEXPECTED_IF_NULL(out_cookie); | |||
| *out_cookie = ""; | |||
| // Before creating the cache, first check if this is a request for shared usage of an existing cache. | |||
| // If two CreateService calls come in with identical connection_id, we need to serialize the creation. | |||
| // The first create will be successful and be given a special cookie. | |||
| UniqueLock lck(&rwLock_); | |||
| auto end = all_caches_.end(); | |||
| auto it = all_caches_.find(connection_id); | |||
| if (it == end) { | |||
| std::unique_ptr<CacheService> cs; | |||
| try { | |||
| cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id); | |||
| RETURN_IF_NOT_OK(cs->ServiceStart()); | |||
| *out_cookie = cs->cookie(); | |||
| all_caches_.emplace(connection_id, std::move(cs)); | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory); | |||
| } | |||
| } else { | |||
| MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service"; | |||
| // We can return OK but we will return a duplicate key so the user can act accordingly: either ignore it | |||
| // or treat it as OK. | |||
| return Status(StatusCode::kDuplicateKey); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| /// This is the main loop the cache server thread(s) are running. | |||
| /// Each thread will pop a request and save the result in the same request. | |||
| /// The sender will wait on the wait post in the request. Once the request | |||
| /// is fulfilled, the server thread will do a post signalling the request | |||
| /// is processed. | |||
| /// \return Status object | |||
| Status CacheServer::ServerRequest() { | |||
| TaskManager::FindMe()->Post(); | |||
| // Loop forever until we are interrupted. | |||
| while (true) { | |||
| BaseRequest *base_rq = nullptr; | |||
| RETURN_IF_NOT_OK(cache_q_->PopFront(&base_rq)); | |||
| auto cs = GetService(base_rq->connection_id_); | |||
| // Except for creating a new session, we expect cs to be non-null. | |||
| switch (base_rq->type_) { | |||
| case BaseRequest::RequestType::kCacheRow: { | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found"; | |||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| auto *rq = reinterpret_cast<CacheRowRequest *>(base_rq); | |||
| // Only if the cookie matches can we accept inserts into a cache that has a build phase | |||
| if (!cs->HasBuildPhase() || rq->cookie_ == cs->cookie()) { | |||
| rq->rc_ = cs->CacheRow(rq->buffers_, &rq->row_id_from_server_); | |||
| } else { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); | |||
| } | |||
| } | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kBatchFetchRows: { | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found"; | |||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| auto *rq = reinterpret_cast<BatchFetchRequest *>(base_rq); | |||
| rq->rc_ = cs->BatchFetch(rq->row_id_, &rq->mem_); | |||
| } | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kCreateCache: { | |||
| // If the cache is already created we still need to run the creation so that we do sanity checks on the | |||
| // client id and return the cache id back to the user. | |||
| auto *rq = reinterpret_cast<CreationCacheRequest *>(base_rq); | |||
| rq->rc_ = CreateService(rq->connection_id_, rq->cache_mem_sz, rq->flag_, &rq->cookie_); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kPurgeCache: { | |||
| if (cs != nullptr) { | |||
| base_rq->rc_ = cs->Purge(); | |||
| } else { | |||
| // it is already purged. Ignore it. | |||
| base_rq->rc_ = Status::OK(); | |||
| } | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kDestroyCache: { | |||
| if (cs != nullptr) { | |||
| // We need a strong lock to protect the map. | |||
| connection_id_type id = base_rq->connection_id_; | |||
| UniqueLock lck(&rwLock_); | |||
| // Erasing from the std::map will invoke the destructor of CacheService. So we don't need to do anything here. | |||
| auto n = all_caches_.erase(id); | |||
| if (n == 0) { | |||
| // It has been destroyed by another duplicate request. | |||
| MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service"; | |||
| } | |||
| base_rq->rc_ = Status::OK(); | |||
| } else { | |||
| // it is already destroyed. Ignore it. | |||
| base_rq->rc_ = Status::OK(); | |||
| } | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kGetStat: { | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; | |||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| auto *rq = reinterpret_cast<GetStatRequest *>(base_rq); | |||
| CacheService::ServiceStat svc_stat; | |||
| rq->rc_ = cs->GetStat(&svc_stat); | |||
| if (rq->rc_.IsOk()) { | |||
| flatbuffers::FlatBufferBuilder fbb; | |||
| ServiceStatMsgBuilder bld(fbb); | |||
| bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached); | |||
| bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached); | |||
| bld.add_max_row_id(svc_stat.max_); | |||
| bld.add_min_row_id(svc_stat.min_); | |||
| bld.add_state(svc_stat.state_); | |||
| auto offset = bld.Finish(); | |||
| fbb.Finish(offset); | |||
| rq->rc_ = rq->mem_.allocate(fbb.GetSize()); | |||
| if (rq->rc_.IsOk()) { | |||
| WritableSlice dest(rq->mem_.GetMutablePointer(), fbb.GetSize()); | |||
| ReadableSlice src(fbb.GetBufferPointer(), fbb.GetSize()); | |||
| RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src)); | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kCacheSchema: { | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; | |||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| auto *rq = reinterpret_cast<CacheSchemaRequest *>(base_rq); | |||
| rq->rc_ = cs->CacheSchema(rq->buf_, rq->len_of_buf_); | |||
| } | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kFetchSchema: { | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; | |||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| auto *rq = reinterpret_cast<FetchSchemaRequest *>(base_rq); | |||
| rq->rc_ = cs->FetchSchema(&rq->mem_); | |||
| } | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kBuildPhaseDone: { | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; | |||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| auto *rq = reinterpret_cast<BuildPhaseDoneRequest *>(base_rq); | |||
| // We can only allow the phase switch if the cookie matches. | |||
| if (rq->cookie_ == cs->cookie()) { | |||
| rq->rc_ = cs->BuildPhaseDone(); | |||
| } else { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); | |||
| } | |||
| } | |||
| break; | |||
| } | |||
| default: | |||
| base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Unknown request type"); | |||
| } | |||
| // Notify it is done, and move on to the next request. | |||
| base_rq->wp_.Set(); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers) | |||
| : top_(spill_path), num_workers_(num_workers) {} | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
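| The handshake above is a queue push paired with a wait-post. A minimal client-side sketch of the flow implied | |||
| by ServerRequest() and PushRequest(), assuming the request exposes its wp_ (WaitPost) and rc_ (Status) members | |||
| to the caller the same way the server code touches them; the GetStatRequest constructor signature here is | |||
| hypothetical (the real client API lives in cache_request.h): | |||
| CacheServer &server = CacheServer::GetInstance(); | |||
| GetStatRequest rq(connection_id);           // hypothetical ctor; see cache_request.h for the real signature | |||
| RETURN_IF_NOT_OK(server.PushRequest(&rq));  // enqueue; one of the num_workers_ server threads pops it | |||
| RETURN_IF_NOT_OK(rq.wp_.Wait());            // block until the server thread calls wp_.Set() | |||
| RETURN_IF_NOT_OK(rq.rc_);                   // per-request result code filled in by the server | |||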
| @@ -0,0 +1,98 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_CACHE_SERVER_H_ | |||
| #define DATASET_ENGINE_CACHE_SERVER_H_ | |||
| #include <algorithm> | |||
| #include <atomic> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <map> | |||
| #include "dataset/engine/cache/cache_service.h" | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/util/arena.h" | |||
| #include "dataset/util/cache_pool.h" | |||
| #include "dataset/util/lock.h" | |||
| #include "dataset/util/service.h" | |||
| #include "dataset/util/services.h" | |||
| #include "dataset/util/system_pool.h" | |||
| #include "dataset/util/queue.h" | |||
| #include "dataset/util/task_manager.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class BaseRequest; | |||
| /// \brief A server which provides CacheService services. | |||
| class CacheServer : public Service { | |||
| public: | |||
| friend class Services; | |||
| using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>; | |||
| CacheServer(const CacheServer &) = delete; | |||
| CacheServer &operator=(const CacheServer &) = delete; | |||
| CacheServer(CacheServer &&) = delete; | |||
| CacheServer &operator=(CacheServer &&) = delete; | |||
| static CacheServer &GetInstance() noexcept { return Services::getCacheServer(); } | |||
| Status DoServiceStart() override; | |||
| Status DoServiceStop() override; | |||
| ~CacheServer() { (void)ServiceStop(); } | |||
| /// \brief For the current demonstration, a cache client contacts the cache server using a Queue. | |||
| /// \param rq The request to be processed | |||
| /// \return Status object | |||
| Status PushRequest(BaseRequest *rq) { | |||
| RETURN_UNEXPECTED_IF_NULL(rq); | |||
| RETURN_IF_NOT_OK(cache_q_->Add(rq)); | |||
| return Status::OK(); | |||
| } | |||
| private: | |||
| mutable RWLock rwLock_; | |||
| std::string top_; | |||
| cache_index all_caches_; | |||
| std::shared_ptr<Queue<BaseRequest *>> cache_q_; | |||
| TaskGroup vg_; | |||
| int32_t num_workers_; | |||
| /// \brief Constructor | |||
| /// \param spill_path Top directory for spilling buffers to. | |||
| /// \param num_workers Number of threads for handling requests. | |||
| explicit CacheServer(const std::string &spill_path, int32_t num_workers = 3); | |||
| /// \brief Locate a cache service from connection id. | |||
| /// \return Pointer to cache service. Null if not found | |||
| CacheService *GetService(connection_id_type id) const; | |||
| /// \brief Create a cache service. We allow multiple clients to create the same cache service. | |||
| /// Subsequent duplicate requests are ignored. The first cache client to create the service will be given | |||
| /// a special unique cookie. | |||
| /// \param[in] connection_id This is from a Cache client. | |||
| /// \param[in] cache_mem_sz | |||
| /// \param[in] flag | |||
| /// \param[out] out_cookie Only the first cache client will be given a special cookie to identify the creator | |||
| /// \return Status object | |||
| Status CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, BaseRequest::CreateCacheFlag flag, | |||
| std::string *out_cookie); | |||
| /// \brief Entry point for all server threads. | |||
| Status ServerRequest(); | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif  // DATASET_ENGINE_CACHE_SERVER_H_ | |||
| @@ -0,0 +1,265 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/engine/cache/cache_service.h" | |||
| #include "dataset/util/slice.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool generate_id) | |||
| : root_(root), | |||
| cache_mem_sz_(mem_sz), | |||
| cp_(nullptr), | |||
| map_(nullptr), | |||
| next_id_(0), | |||
| generate_id_(generate_id), | |||
| schema_key_(-1), | |||
| st_(generate_id ? State::kBuildPhase : State::kNone) {} | |||
| CacheService::~CacheService() { (void)ServiceStop(); } | |||
| bool CacheService::UseArena() { | |||
| // If fixed size, use Arena instead of the pool from global context. | |||
| return (cache_mem_sz_ > 0); | |||
| } | |||
| Status CacheService::DoServiceStart() { | |||
| std::shared_ptr<MemoryPool> mp; | |||
| if (UseArena()) { | |||
| // Create a fixed size arena based on the parameter. | |||
| std::shared_ptr<Arena> arena; | |||
| RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_)); | |||
| mp = std::move(arena); | |||
| } else { | |||
| // Unlimited size. Simply use a system pool. Another choice is CircularPool. | |||
| mp = std::make_shared<SystemPool>(); | |||
| } | |||
| // Put together a CachePool for backing up the Tensor | |||
| cp_ = std::make_shared<CachePool>(CachePool::value_allocator(mp), root_); | |||
| RETURN_IF_NOT_OK(cp_->ServiceStart()); | |||
| // Set up the B+ tree as well. But use the system pool instead. | |||
| map_ = std::make_shared<row_map>(); | |||
| // Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name. | |||
| cookie_ = cp_->MyName(); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::DoServiceStop() { | |||
| if (cp_ != nullptr) { | |||
| RETURN_IF_NOT_OK(cp_->ServiceStop()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) { | |||
| SharedLock rw(&rw_lock_); | |||
| RETURN_UNEXPECTED_IF_NULL(row_id_generated); | |||
| if (st_ == State::kFetchPhase) { | |||
| // For this kind of cache service, once the build phase has moved into the fetch phase, we can't | |||
| // allow others to cache more rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||
| } | |||
| try { | |||
| // The first buffer is a flatbuffer which describes the buffers that follow | |||
| auto fb = buf.front(); | |||
| RETURN_UNEXPECTED_IF_NULL(fb); | |||
| auto msg = GetTensorRowHeaderMsg(fb); | |||
| // If the server side is designed to ignore incoming row id, we generate row id. | |||
| if (generate_id_) { | |||
| *row_id_generated = GetNextRowId(); | |||
| // Some debug information on how many rows we have generated so far. | |||
| if ((*row_id_generated) % 1000 == 0) { | |||
| MS_LOG(DEBUG) << "Number of rows cached: " << *row_id_generated; | |||
| } | |||
| } else { | |||
| if (msg->row_id() < 0) { | |||
| std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id()); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| *row_id_generated = msg->row_id(); | |||
| } | |||
| auto size_of_this = msg->size_of_this(); | |||
| auto column_hdr = msg->column(); | |||
| // The number of tensor buffers should match the number of columns plus one. | |||
| if (buf.size() != column_hdr->size() + 1) { | |||
| std::string errMsg = "Column count does not match. Expected " + std::to_string(column_hdr->size() + 1) + | |||
| " but got " + std::to_string(buf.size()); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| // Next we store it either in memory or on disk. Low-level code will consolidate everything into one piece. | |||
| std::vector<ReadableSlice> all_data; | |||
| all_data.reserve(column_hdr->size() + 1); | |||
| all_data.emplace_back(fb, size_of_this); | |||
| for (uint32_t i = 0; i < column_hdr->size(); ++i) { | |||
| all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i)); | |||
| } | |||
| // Now we cache the flat buffer. | |||
| CachePool::key_type key; | |||
| RETURN_IF_NOT_OK(cp_->Insert(all_data, &key)); | |||
| Status rc = map_->DoInsert(*row_id_generated, key); | |||
| if (rc == Status(StatusCode::kDuplicateKey)) { | |||
| MS_LOG(DEBUG) << "Ignoring duplicate key"; | |||
| } else { | |||
| RETURN_IF_NOT_OK(rc); | |||
| } | |||
| return Status::OK(); | |||
| } catch (const std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| } | |||
| } | |||
| std::ostream &operator<<(std::ostream &out, const CacheService &cs) { | |||
| // Show the custom derived-internal stuff | |||
| out << "\nCache memory size: " << cs.cache_mem_sz_; | |||
| out << "\nSpill path: "; | |||
| if (cs.root_.empty()) { | |||
| out << "None"; | |||
| } else { | |||
| out << cs.GetSpillPath(); | |||
| } | |||
| return out; | |||
| } | |||
| Path CacheService::GetSpillPath() const { return cp_->GetSpillPath(); } | |||
| Status CacheService::Purge() { | |||
| // First we must lock exclusively. No one else can cache/restore anything. | |||
| UniqueLock rw(&rw_lock_); | |||
| RETURN_IF_NOT_OK(cp_->ServiceStop()); | |||
| auto new_map = std::make_shared<row_map>(); | |||
| map_.reset(); | |||
| map_ = std::move(new_map); | |||
| next_id_ = 0; | |||
| RETURN_IF_NOT_OK(cp_->ServiceStart()); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::GetStat(CacheService::ServiceStat *out) { | |||
| SharedLock rw(&rw_lock_); | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| if (st_ == State::kNone || st_ == State::kFetchPhase) { | |||
| out->stat_ = cp_->GetStat(); | |||
| out->state_ = static_cast<ServiceStat::state_type>(st_); | |||
| auto it = map_->begin(); | |||
| if (it != map_->end()) { | |||
| out->min_ = it.key(); | |||
| auto end_it = map_->end(); | |||
| --end_it; | |||
| out->max_ = end_it.key(); | |||
| } | |||
| } else { | |||
| out->state_ = static_cast<ServiceStat::state_type>(st_); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| SharedLock rw(&rw_lock_); | |||
| if (st_ == State::kBuildPhase) { | |||
| // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't fetch rows while in build phase"); | |||
| } | |||
| const auto num_elements = v.size(); | |||
| int64_t mem_sz = (num_elements + 1) * sizeof(int64_t); | |||
| int64_t data_offset = mem_sz; | |||
| std::vector<int64_t> sz_v; | |||
| std::vector<CachePool::key_type> keys; | |||
| sz_v.reserve(num_elements); | |||
| keys.reserve(num_elements); | |||
| for (auto row_id : v) { | |||
| auto r = map_->Search(row_id); | |||
| if (r.second) { | |||
| auto &it = r.first; | |||
| CachePool::key_type key = it.value(); | |||
| auto sz = cp_->GetSize(key); | |||
| if (sz == 0) { | |||
| std::string errMsg = "Key not found: "; | |||
| errMsg += std::to_string(key); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| keys.push_back(key); | |||
| sz_v.push_back(sz); | |||
| mem_sz += sz; | |||
| } else { | |||
| keys.push_back(-1); | |||
| sz_v.push_back(0); | |||
| } | |||
| } | |||
| MemGuard<uint8_t> mem; | |||
| RETURN_IF_NOT_OK(mem.allocate(mem_sz)); | |||
| auto *offset_array = reinterpret_cast<int64_t *>(mem.GetMutablePointer()); | |||
| offset_array[0] = data_offset; | |||
| WritableSlice all(mem.GetMutablePointer(), mem.GetSizeInBytes()); | |||
| for (size_t i = 0; i < num_elements; ++i) { | |||
| auto sz = sz_v.at(i); | |||
| offset_array[i + 1] = offset_array[i] + sz; | |||
| if (sz > 0) { | |||
| WritableSlice row_data(all, offset_array[i], sz); | |||
| auto key = keys.at(i); | |||
| size_t bytesRead = 0; | |||
| RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead)); | |||
| if (bytesRead != static_cast<size_t>(sz)) { | |||
| MS_LOG(ERROR) << "Unexpected length. Read " << bytesRead << ". Expected " << sz << "." | |||
| << " Internal key: " << key << "\n"; | |||
| RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); | |||
| } | |||
| } | |||
| } | |||
| *out = std::move(mem); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::CacheSchema(const void *buf, int64_t len) { | |||
| SharedLock rw(&rw_lock_); | |||
| if (st_ == State::kFetchPhase) { | |||
| // For this kind of cache service, once the build phase has moved into the fetch phase, we can't | |||
| // allow others to cache anything more. | |||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||
| } | |||
| // This is a special request and we need to remember where we store it. | |||
| // In case we are calling the same function from multiple threads, only | |||
| // the first one is considered. The rest are ignored. | |||
| CachePool::key_type cur_key = schema_key_; | |||
| CachePool::key_type key; | |||
| if (cur_key < 0) { | |||
| RETURN_IF_NOT_OK(cp_->Insert({ReadableSlice(buf, len)}, &key)); | |||
| auto result = std::atomic_compare_exchange_strong(&schema_key_, &cur_key, key); | |||
| MS_LOG(DEBUG) << "Caching Schema. Result = " << result; | |||
| } else { | |||
| MS_LOG(DEBUG) << "Caching Schema already done"; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::FetchSchema(MemGuard<uint8_t> *out) const { | |||
| SharedLock rw(&rw_lock_); | |||
| if (st_ == State::kBuildPhase) { | |||
| // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't fetch schema while in build phase"); | |||
| } | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| MemGuard<uint8_t> mem; | |||
| if (schema_key_ >= 0) { | |||
| auto len = cp_->GetSize(schema_key_); | |||
| RETURN_IF_NOT_OK(mem.allocate(len)); | |||
| auto slice = WritableSlice(mem.GetMutablePointer(), len); | |||
| RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice)); | |||
| *out = std::move(mem); | |||
| } else { | |||
| return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached"); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::BuildPhaseDone() { | |||
| if (HasBuildPhase()) { | |||
| // Exclusive lock to switch phase | |||
| UniqueLock rw(&rw_lock_); | |||
| st_ = State::kFetchPhase; | |||
| return Status::OK(); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase"); | |||
| } | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
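| A note on the BatchFetch() layout above: the output buffer starts with an array of num_elements + 1 int64 | |||
| offsets, so offset_array[i + 1] - offset_array[i] is the size of row i and a zero-length span marks a cache | |||
| miss. A short decoding sketch under those assumptions (the real decoding belongs to CacheClient, which is | |||
| outside this hunk): | |||
| void DecodeBatch(const uint8_t *buf, size_t num_rows) { | |||
| auto *offsets = reinterpret_cast<const int64_t *>(buf); | |||
| for (size_t i = 0; i < num_rows; ++i) { | |||
| int64_t sz = offsets[i + 1] - offsets[i]; | |||
| if (sz == 0) continue;  // cache miss for this row id | |||
| const uint8_t *row = buf + offsets[i]; | |||
| auto msg = GetTensorRowHeaderMsg(row);  // row data starts with the TensorRowHeaderMsg flatbuffer | |||
| (void)msg;  // followed by the tensor data buffers sized by msg->data_sz() | |||
| } | |||
| } | |||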
| @@ -0,0 +1,143 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_CACHE_SERVICE_H_ | |||
| #define DATASET_ENGINE_CACHE_SERVICE_H_ | |||
| #include <algorithm> | |||
| #include <atomic> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <type_traits> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "./de_tensor_generated.h" | |||
| #include "dataset/core/global_context.h" | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/engine/cache/cache_request.h" | |||
| #include "dataset/util/arena.h" | |||
| #include "dataset/util/btree.h" | |||
| #include "dataset/util/cache_pool.h" | |||
| #include "dataset/util/service.h" | |||
| #include "dataset/util/services.h" | |||
| #include "dataset/util/system_pool.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| struct CacheStat; | |||
| /// \brief A cache service for storing/fetching buffers in an in-memory cache, which may spill to disk if the | |||
| /// cache service is created with spill support | |||
| class CacheService : public Service { | |||
| public: | |||
| friend class CacheServer; | |||
| using row_map = BPlusTree<row_id_type, CachePool::key_type>; | |||
| enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase }; | |||
| /// \brief Constructor | |||
| /// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited | |||
| /// \param root Spill path. Empty string means no spilling | |||
| /// \param generate_id If the cache service should generate row ids for the buffers that are cached. | |||
| /// For non-mappable dataset, this should be set to true. | |||
| CacheService(uint64_t mem_sz, const std::string &root, bool generate_id); | |||
| ~CacheService(); | |||
| /// \brief For fixed size memory, we will create an Arena. | |||
| /// \return false if unlimited memory. | |||
| bool UseArena(); | |||
| Status DoServiceStart() override; | |||
| Status DoServiceStop() override; | |||
| /// \brief Main function to cache a row which is in form a series of buffers. | |||
| /// The first buffer is a Google flatbuffer which describes the buffers that follow. | |||
| /// \param[in] buf Vector of buffer | |||
| /// \param[out] row_id_generated The row id assigned to this row if any | |||
| /// \return Status object | |||
| Status CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated); | |||
| /// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded | |||
| /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row. | |||
| /// \param[in] v A vector of row id. | |||
| /// \param[out] out A contiguous memory buffer that holds the requested rows. | |||
| /// \return Status object | |||
| Status BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const; | |||
| /// \brief Getter function | |||
| /// \return Spilling path | |||
| Path GetSpillPath() const; | |||
| /// \brief A structure returned from the cache server for statistics request. | |||
| class ServiceStat { | |||
| public: | |||
| using state_type = std::underlying_type<State>::type; | |||
| ServiceStat() : min_(0), max_(0), state_(0) {} | |||
| CachePool::CacheStat stat_{}; | |||
| row_id_type min_; | |||
| row_id_type max_; | |||
| state_type state_; | |||
| }; | |||
| /// \brief Statistics for the current service | |||
| /// \param[in/out] out A pointer to a pre-allocated ServiceStat structure | |||
| /// \return Status object | |||
| Status GetStat(ServiceStat *out); | |||
| /// \brief Cache schema | |||
| /// \param buf A Google Flatbuffer that contains the schema | |||
| /// \param len size of the buffer | |||
| /// \return Status object | |||
| Status CacheSchema(const void *buf, int64_t len); | |||
| /// \brief Fetch schema | |||
| /// \param out A contiguous memory that contains the serialized form of schema. | |||
| /// \return Status object | |||
| Status FetchSchema(MemGuard<uint8_t> *out) const; | |||
| /// \brief Purge the content of a cache | |||
| /// \return Status object | |||
| Status Purge(); | |||
| /// \brief Overload the << operator to print a cache service | |||
| /// \param out std::ostream | |||
| /// \param cs A cache service | |||
| /// \return std::ostream | |||
| friend std::ostream &operator<<(std::ostream &out, const CacheService &cs); | |||
| /// \brief Every cache service has a cookie. If the cookie of a CacheClient matches this cookie, this CacheClient | |||
| /// is the creator | |||
| /// \return Cookie | |||
| std::string cookie() const { return cookie_; } | |||
| /// \brief If this cache service generates row ids for the buffers cached, its life cycle is divided into two | |||
| /// phases, a build phase and a read phase. | |||
| /// \return True if has two phases. | |||
| bool HasBuildPhase() const { return generate_id_; } | |||
| /// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call. | |||
| /// \return Status object | |||
| Status BuildPhaseDone(); | |||
| private: | |||
| mutable RWLock rw_lock_; | |||
| std::string root_; | |||
| uint64_t cache_mem_sz_; | |||
| std::shared_ptr<CachePool> cp_; | |||
| std::shared_ptr<row_map> map_; | |||
| std::atomic<row_id_type> next_id_; | |||
| bool generate_id_; | |||
| std::atomic<CachePool::key_type> schema_key_; | |||
| std::string cookie_; | |||
| State st_; | |||
| /// \brief Private function to generate a row id | |||
| /// \return Row id assigned. | |||
| row_id_type GetNextRowId() { return next_id_.fetch_add(1); } | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_CACHE_SERVICE_H_ | |||
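| A usage sketch of the two-phase protocol documented above, for a cache created with generate_id = true | |||
| (assumes the buffers vector has been serialized elsewhere; mem_sz 0 selects the unlimited system pool): | |||
| CacheService cs(/*mem_sz=*/0, /*root=*/"", /*generate_id=*/true);  // starts in State::kBuildPhase | |||
| RETURN_IF_NOT_OK(cs.ServiceStart()); | |||
| row_id_type id; | |||
| RETURN_IF_NOT_OK(cs.CacheRow(buffers, &id));  // only legal before the phase switch | |||
| RETURN_IF_NOT_OK(cs.BuildPhaseDone());        // the creator flips kBuildPhase -> kFetchPhase | |||
| MemGuard<uint8_t> out; | |||
| RETURN_IF_NOT_OK(cs.BatchFetch({id}, &out));  // only legal after the phase switch | |||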
| @@ -0,0 +1,81 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| namespace mindspore.dataset; | |||
| /// Type of a Tensor | |||
| enum TensorType : byte { | |||
| DE_UNKNOWN = 0, | |||
| DE_BOOL = 1, | |||
| DE_INT8 = 2, | |||
| DE_UINT8 = 3, | |||
| DE_INT16 = 4, | |||
| DE_UINT16 = 5, | |||
| DE_INT32 = 6, | |||
| DE_UINT32 = 7, | |||
| DE_INT64 = 8, | |||
| DE_UINT64 = 9, | |||
| DE_FLOAT16 = 10, | |||
| DE_FLOAT32 = 11, | |||
| DE_FLOAT64 = 12, | |||
| DE_STRING = 13 | |||
| } | |||
| /// The meta information of a Tensor | |||
| /// \note Only the type and shape are considered meta information. Tensor data is excluded. | |||
| table TensorMetaMsg { | |||
| dims:[int64] (required); | |||
| type:TensorType; | |||
| } | |||
| /// This is the first buffer that is sent to a Cache server when a TensorRow is serialized. | |||
| /// \param row_id The row id of the TensorRow | |||
| /// \param column The meta information of each Tensor in the row | |||
| /// \param size_of_this The size of this serialized buffer | |||
| /// \param data_sz The size of each tensor data buffer that follows | |||
| table TensorRowHeaderMsg { | |||
| row_id:int64; | |||
| column:[TensorMetaMsg] (required); | |||
| size_of_this:int64; | |||
| data_sz:[int64] (required); | |||
| } | |||
| root_type TensorRowHeaderMsg; | |||
| /// A list of row ids | |||
| table TensorRowIds { | |||
| row_id:[int64] (required); | |||
| } | |||
| /// Statistics returned from each cache service | |||
| /// \note It must match CacheService::ServiceStat | |||
| table ServiceStatMsg { | |||
| num_mem_cached:int64; | |||
| num_disk_cached:int64; | |||
| min_row_id:int64; | |||
| max_row_id:int64; | |||
| state:int8; | |||
| } | |||
| /// Column description of each column in a schema | |||
| table ColumnNameMsg { | |||
| name:string; | |||
| id:int32; | |||
| } | |||
| /// Serialized form of a schema | |||
| table SchemaMsg { | |||
| column:[ColumnNameMsg]; | |||
| } | |||
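| For reference, a sketch of building and reading the root table with the flatc-generated C++ API (names follow | |||
| standard flatbuffers codegen for this schema; size_of_this is left at 0 here since the real writer only knows | |||
| it after serialization): | |||
| flatbuffers::FlatBufferBuilder fbb; | |||
| std::vector<int64_t> dims = {2, 3};  // a 2x3 tensor | |||
| auto meta = CreateTensorMetaMsg(fbb, fbb.CreateVector(dims), TensorType_DE_FLOAT32); | |||
| std::vector<flatbuffers::Offset<TensorMetaMsg>> cols = {meta}; | |||
| std::vector<int64_t> data_sz = {24};  // 2 * 3 * sizeof(float) bytes in the buffer that follows | |||
| fbb.Finish(CreateTensorRowHeaderMsg(fbb, /*row_id=*/0, fbb.CreateVector(cols), /*size_of_this=*/0, | |||
| fbb.CreateVector(data_sz))); | |||
| auto msg = GetTensorRowHeaderMsg(fbb.GetBufferPointer());  // as CacheService::CacheRow reads it | |||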
| @@ -24,10 +24,8 @@ namespace dataset { | |||
| // Description: This is the main constructor that is used for making a buffer | |||
| DataBuffer::DataBuffer(int32_t id, BufferFlags flags) : buffer_id_(id), tensor_table_(nullptr), buffer_flags_(flags) {} | |||
| // Name: print() | |||
| // Description: A function that prints info about the DataBuffer (base class version) | |||
| void DataBuffer::Print(std::ostream &out, // In: The output stream to print to | |||
| bool show_all) const { // In: T/F if it should show everything | |||
| // A method for debug printing of the buffer | |||
| void DataBuffer::Print(std::ostream &out, bool show_all) const { | |||
| out << "bufferId: " << buffer_id_ << "\nflags: " << std::hex << buffer_flags_ << std::dec << "\n"; | |||
| // If the column counts are set then it means that data has been set into | |||
| @@ -46,11 +44,6 @@ void DataBuffer::Print(std::ostream &out, // In: The output stream to print | |||
| } | |||
| } | |||
| Status DataBuffer::Load() { | |||
| std::string err_msg = "Base class load called, but it does not have an implementation!"; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| // Remove me!! Callers should fetch rows via pop | |||
| Status DataBuffer::GetTensor(std::shared_ptr<Tensor> *ptr, int32_t row_id, int32_t col_id) const { | |||
| if (row_id < tensor_table_->size() && col_id < tensor_table_->at(row_id).size()) { | |||
| @@ -92,8 +85,5 @@ Status DataBuffer::SliceOff(int64_t number_of_rows) { | |||
| return Status::OK(); | |||
| } | |||
| // Destructor | |||
| DataBuffer::~DataBuffer() {} | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -29,11 +29,9 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // The DataBuffer class is a base class that will represent the data for n values based | |||
| // on a unique row id for each row of data. | |||
| // There can be different types of DataBuffers to abstract over how the data is stored | |||
| // in memory and acquired from storage. | |||
| // Each buffer holds a range of consecutive row id's. | |||
| /// \brief The DataBuffer class is a container of tensor data and is the unit of transmission between | |||
| /// connectors of dataset operators. Inside the buffer, tensors are organized into a table-like format | |||
| /// where n TensorRows may consist of m tensors (columns). | |||
| class DataBuffer { | |||
| public: | |||
| // Buffer flags | |||
| @@ -47,13 +45,13 @@ class DataBuffer { | |||
| // Description: This is the main constructor that is used for making a buffer | |||
| DataBuffer(int32_t id, BufferFlags flags); | |||
| // Destructor | |||
| virtual ~DataBuffer(); | |||
| /// \brief default destructor | |||
| ~DataBuffer() = default; | |||
| // Name: print() | |||
| // Description: A function that prints info about the DataBuffer (base class version) | |||
| virtual void Print(std::ostream &out, // In: The output stream to print to | |||
| bool show_all) const; // In: T/F if it should show everything | |||
| /// \brief A method for debug printing of the buffer | |||
| /// \param[inout] out The stream to write to | |||
| /// \param[in] show_all A boolean to toggle between details and summary printing | |||
| void Print(std::ostream &out, bool show_all) const; | |||
| // Provide stream operator for displaying it | |||
| friend std::ostream &operator<<(std::ostream &out, const DataBuffer &cb) { | |||
| @@ -61,10 +59,6 @@ class DataBuffer { | |||
| return out; | |||
| } | |||
| // Name: load() | |||
| // Description: populates the DataBuffer with data based on it's id | |||
| virtual Status Load(); | |||
| // Convenience getter functions for flag checking | |||
| bool eof() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOF)); } | |||
| @@ -17,7 +17,11 @@ set(DATASET_ENGINE_DATASETOPS_SRC_FILES | |||
| take_op.cc | |||
| shuffle_op.cc | |||
| zip_op.cc | |||
| concat_op.cc | |||
| cache_base_op.cc | |||
| cache_lookup_op.cc | |||
| cache_op.cc | |||
| cache_merge_op.cc | |||
| ) | |||
| if (ENABLE_PYTHON) | |||
| @@ -0,0 +1,185 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/engine/datasetops/cache_base_op.h" | |||
| #include <iomanip> | |||
| #include <iostream> | |||
| #include "dataset/engine/execution_tree.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // A print method typically used for debugging | |||
| void CacheBase::Print(std::ostream &out, bool show_all) const { | |||
| // Always show the id and name as the first line, regardless of whether this is a summary or detailed print | |||
| out << "(" << std::setw(2) << operator_id_ << ") <" << Name() << ">:"; | |||
| if (!show_all) { | |||
| // Call the super class for displaying any common 1-liner info | |||
| ParallelOp::Print(out, show_all); | |||
| out << "\n"; | |||
| } else { | |||
| // Call the super class for displaying any common detailed info | |||
| ParallelOp::Print(out, show_all); | |||
| // Then show any custom derived-internal stuff | |||
| out << "\nCache client:\n" << *cache_client_ << "\n\n"; | |||
| } | |||
| } | |||
| // Overrides base class reset method. When an operator does a reset, it cleans up any state | |||
| // info from its previous execution and then initializes itself so that it can be executed | |||
| // again. | |||
| Status CacheBase::Reset() { | |||
| if (sampler_ != nullptr) { | |||
| RETURN_IF_NOT_OK(sampler_->ResetSampler()); | |||
| } | |||
| // Wake up the workers to get them going again in a new epoch | |||
| MS_LOG(DEBUG) << Name() << " resetting."; | |||
| epoch_sync_.Set(); | |||
| return Status::OK(); | |||
| } | |||
| CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | |||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler) | |||
| : ParallelOp(num_workers, op_connector_size, sampler), | |||
| cache_client_(cache_client), | |||
| rows_per_buffer_(rows_per_buf), | |||
| // We can cause deadlock if this internal Connector size is too small. | |||
| keys_miss_(num_workers_, 1, 1024) { | |||
| io_block_queues_.Init(num_workers, op_connector_size); | |||
| } | |||
| // Common function to fetch samples from the sampler and send them using the io_block_queues to | |||
| // the parallel workers | |||
| Status CacheBase::FetchSamplesToWorkers() { | |||
| int64_t buf_cnt = 0; | |||
| int64_t wait_cnt = 0; | |||
| do { | |||
| epoch_sync_.Clear(); | |||
| std::vector<row_id_type> keys; | |||
| int64_t row_cnt = 0; | |||
| keys.reserve(rows_per_buffer_); | |||
| std::unique_ptr<DataBuffer> sampler_buffer; | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| while (!sampler_buffer->eoe()) { | |||
| TensorRow sample_row; | |||
| RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); | |||
| std::shared_ptr<Tensor> sample_ids = sample_row[0]; | |||
| for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) { | |||
| keys.push_back(*itr); | |||
| ++row_cnt; | |||
| if (row_cnt % rows_per_buffer_ == 0) { | |||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||
| RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); | |||
| keys.clear(); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| if (!keys.empty()) { | |||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||
| RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); | |||
| } | |||
| // send the eoe | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| // If repeating but not the last repeat, wait for reset. | |||
| if (BitTest(op_ctrl_flags_, kDeOpRepeated) && !BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { | |||
| MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt; | |||
| RETURN_IF_NOT_OK(epoch_sync_.Wait()); | |||
| } else { | |||
| // We can break out from the loop. | |||
| break; | |||
| } | |||
| } while (true); | |||
| // Flow the eof before exit | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof))); | |||
| // Ask all the workers to quit. | |||
| for (int32_t i = 0; i < num_workers_; i++) { | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheBase::FetchFromCache(int32_t worker_id) { | |||
| int64_t buffer_id = worker_id; | |||
| std::unique_ptr<IOBlock> blk; | |||
| do { | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk)); | |||
| if (blk->eof()) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))); | |||
| } else if (blk->eoe()) { | |||
| if (AllowCacheMiss()) { | |||
| // This code path is for CacheLookupOp acting as a sampler. If we get an eoe from | |||
| // the sampler, send an eoe to the physical leaf op as well. | |||
| std::vector<row_id_type> eoe; | |||
| eoe.push_back(eoe_row_id); | |||
| RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, eoe)); | |||
| } | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| } else { | |||
| std::vector<int64_t> keys; | |||
| RETURN_IF_NOT_OK(blk->GetKeys(&keys)); | |||
| if (keys.empty()) { | |||
| // An empty key list is a quit signal for workers | |||
| break; | |||
| } | |||
| std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone); | |||
| std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>(); | |||
| TensorTable ttbl; | |||
| RETURN_IF_NOT_OK(cache_client_->GetRows(keys, &ttbl)); | |||
| auto row_it = ttbl.begin(); | |||
| std::vector<row_id_type> cache_miss; | |||
| cache_miss.reserve(keys.size()); | |||
| for (auto row_id : keys) { | |||
| auto &row = *row_it; | |||
| if (row.empty()) { | |||
| if (AllowCacheMiss()) { | |||
| cache_miss.push_back(row_id); | |||
| } else { | |||
| std::string errMsg = "Row id " + std::to_string(row_id) + " not found."; | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| } | |||
| que->push_back(std::move(row)); | |||
| ++row_it; | |||
| } | |||
| db->set_tensor_table(std::move(que)); | |||
| if (AllowCacheMiss()) { | |||
| // Because of the way the connector works, we push unconditionally even if cache_miss is empty. | |||
| RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, cache_miss)); | |||
| } | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); | |||
| buffer_id += num_workers_; | |||
| } | |||
| } while (true); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheBase::RegisterResources() { | |||
| RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| return Status::OK(); | |||
| } | |||
| CacheBase::~CacheBase() {} | |||
| Status CacheBase::UpdateColumnMapFromCache() { | |||
| Status rc; | |||
| // Get the schema from the server. It may not be there yet. So tolerate the error. | |||
| if (column_name_id_map_.empty()) { | |||
| rc = cache_client_->FetchSchema(&column_name_id_map_); | |||
| if (rc == Status(StatusCode::kFileNotExist)) { | |||
| MS_LOG(DEBUG) << "Schema not in the server yet."; | |||
| rc = Status::OK(); | |||
| } | |||
| } | |||
| return rc; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
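| A worked example of the distribution in FetchSamplesToWorkers(): with rows_per_buffer_ = 2, num_workers_ = 2 | |||
| and sampled ids {0, 1, 2, 3, 4}, worker 0 gets the IOBlock {0, 1}, worker 1 gets {2, 3}, the leftover {4} | |||
| round-robins back to worker 0, the single eoe block lands on worker 1, the eof (on the last repeat) on | |||
| worker 0, and every worker finally receives an empty-key quit block. | |||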
| @@ -0,0 +1,108 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ | |||
| #define DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "dataset/engine/cache/cache_client.h" | |||
| #include "dataset/engine/cache/cache_service.h" | |||
| #include "dataset/engine/datasetops/parallel_op.h" | |||
| #include "dataset/engine/datasetops/repeat_op.h" | |||
| #include "dataset/engine/datasetops/source/io_block.h" | |||
| #include "dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "dataset/util/queue.h" | |||
| #include "dataset/util/wait_post.h" | |||
| #include "dataset/engine/datasetops/cache_base_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities. | |||
| /// \see CacheOp | |||
| /// \see CacheLookupOp | |||
| class CacheBase : public ParallelOp { | |||
| public: | |||
| /// \brief Base class constructor | |||
| /// \param num_workers Number of parallel workers | |||
| /// \param op_connector_size Connector size | |||
| /// \param rows_per_buf Number of rows per buffer | |||
| /// \param cache_client CacheClient for communication to the CacheServer | |||
| /// \param sampler Sampler which is mandatory | |||
| CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | |||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler); | |||
| /// \brief Destructor | |||
| ~CacheBase(); | |||
| constexpr static int eoe_row_id = -1; | |||
| /// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state | |||
| /// info from its previous execution and then initializes itself so that it can be executed | |||
| /// again. | |||
| /// \return Status - The error code return | |||
| Status Reset() override; | |||
| /// \brief A print method typically used for debugging | |||
| /// \param out The output stream to write output to | |||
| /// \param show_all A bool to control if you want to show all info or just a summary | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| /// \brief << Stream output operator overload | |||
| /// \note This allows you to write the debug print info using stream operators | |||
| /// \param out reference to the output stream being overloaded | |||
| /// \param mo reference to the CacheOp to display | |||
| /// \return the output stream must be returned | |||
| friend std::ostream &operator<<(std::ostream &out, const CacheBase &mo) { | |||
| mo.Print(out, false); | |||
| return out; | |||
| } | |||
| /// \brief Getter for the cache client | |||
| /// \return shared ptr to the cache client | |||
| std::shared_ptr<CacheClient> cache_client() { return cache_client_; } | |||
| /// \brief Setter for the cache client | |||
| void SetCacheClient(std::shared_ptr<CacheClient> cache_client) { cache_client_ = std::move(cache_client); } | |||
| /// \brief Derived class must implement this method to indicate whether a cache miss is allowed or treated as an error | |||
| virtual bool AllowCacheMiss() = 0; | |||
| protected: | |||
| std::shared_ptr<CacheClient> cache_client_; | |||
| WaitPost epoch_sync_; | |||
| int32_t rows_per_buffer_; | |||
| Connector<std::vector<row_id_type>> keys_miss_; | |||
| /// \brief Common function to register resources for interrupt | |||
| /// \note Derived classes should override this function to register extra resources | |||
| virtual Status RegisterResources(); | |||
| /// \brief This function is called by main thread to send samples to the worker thread. | |||
| /// \note It is a non-virtual function | |||
| /// \return Status object | |||
| Status FetchSamplesToWorkers(); | |||
| /// \brief This function is called by each worker to fetch rows from the cache server for a given set of | |||
| /// sample row ids | |||
| /// \return Status object | |||
| Status FetchFromCache(int32_t worker_id); | |||
| /// \brief Get the column map from cache server | |||
| Status UpdateColumnMapFromCache(); | |||
| private: | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ | |||
| @@ -0,0 +1,130 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/engine/datasetops/cache_lookup_op.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "dataset/core/config_manager.h" | |||
| #include "dataset/core/constants.h" | |||
| #include "dataset/core/global_context.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/system/crc32c.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Builder constructor. Creates the builder object. | |||
| CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| build_num_workers_ = cfg->num_parallel_workers(); | |||
| rows_per_buffer_ = cfg->rows_per_buffer(); | |||
| build_op_connector_size_ = cfg->op_connector_size(); | |||
| } | |||
| // Check if the required parameters are set by the builder. | |||
| Status CacheLookupOp::Builder::SanityCheck() const { | |||
| if (build_cache_client_ == nullptr) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheLookupOp requires a CacheClient"); | |||
| } | |||
| // Make sure the cache client has a valid session | |||
| if (!build_cache_client_->session_id()) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, | |||
| "Cache client for CacheLookupOp is missing session id"); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // The builder "build" method creates the final object and does some init on it | |||
| Status CacheLookupOp::Builder::Build(std::shared_ptr<CacheLookupOp> *ptr) { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| *ptr = std::make_shared<CacheLookupOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_, | |||
| build_cache_client_, build_sampler_); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheLookupOp::operator()() { | |||
| if (!sampler_) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, | |||
| "CacheLookupOp requires a sampler before it can be executed!"); | |||
| } | |||
| RETURN_IF_NOT_OK(RegisterResources()); | |||
| // Kick off the workers | |||
| RETURN_IF_NOT_OK( | |||
| tree_->LaunchWorkers(num_workers_, std::bind(&CacheLookupOp::WorkerEntry, this, std::placeholders::_1))); | |||
| // required task group sync after launching workers | |||
| TaskManager::FindMe()->Post(); | |||
| // We have to wait until the leaf op has completed the handshake with us. | |||
| RETURN_IF_NOT_OK(leaf_op_wp_.Wait()); | |||
| RETURN_IF_NOT_OK(FetchSamplesToWorkers()); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheLookupOp::WorkerEntry(int32_t worker_id) { | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(FetchFromCache(worker_id)); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheLookupOp::ResetSampler() { return Status::OK(); } | |||
| Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||
| // We act both as a sampler and as a dataset op. During the handshake with the leaf op, | |||
| // we must wait until the leaf op has indexed everything. | |||
| RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(op)); | |||
| // Now we notify the main thread handshake has finished. | |||
| leaf_op_wp_.Set(); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); } | |||
| void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); } | |||
| Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| std::vector<row_id_type> cache_miss; | |||
| RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss)); | |||
| // Skip the case where we have no cache miss; we can't return empty samples. | |||
| while (cache_miss.empty()) { | |||
| RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss)); | |||
| } | |||
| // Special code for eoe | |||
| if (cache_miss.at(0) == eoe_row_id) { | |||
| *out_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| } else { | |||
| std::shared_ptr<Tensor> sample_ts; | |||
| RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ts, cache_miss.size())); | |||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone); | |||
| auto idPtr = sample_ts->begin<int64_t>(); | |||
| for (size_t i = 0; i < cache_miss.size(); ++i) { | |||
| *idPtr = cache_miss.at(i); | |||
| ++idPtr; | |||
| } | |||
| TensorRow row; | |||
| row.push_back(sample_ts); | |||
| (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheLookupOp::RegisterResources() { | |||
| RETURN_IF_NOT_OK(CacheBase::RegisterResources()); | |||
| RETURN_IF_NOT_OK(leaf_op_wp_.Register(tree_->AllTasks())); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheLookupOp::ComputeColMap() { | |||
| // We don't know the column map at this point unless we contact the cache server | |||
| // to fetch the schema but the cache server may not have it at this point either. | |||
| // So we will just return OK and let MergeOp (our parent) handle it. | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status CacheLookupOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(shared_from_base<CacheLookupOp>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,122 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ | |||
| #define DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ | |||
| #include <atomic> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "dataset/engine/datasetops/cache_base_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief Provides a memory/disk cache that acts as a save-point within a mappable dataset. | |||
| /// \note For non-mappable dataset, please see CacheOp | |||
| /// \see CacheOp | |||
| class CacheLookupOp : public CacheBase, public Sampler { | |||
| public: | |||
| class Builder { | |||
| public: | |||
| /// \brief Builder constructor. Creates the builder object. | |||
| /// \note No default args | |||
| Builder(); | |||
| /// Default destructor | |||
| ~Builder() = default; | |||
| /// Setter method. | |||
| /// \return Builder setter method returns reference to the builder. | |||
| Builder &SetNumWorkers(int32_t num_workers) { | |||
| build_num_workers_ = num_workers; | |||
| return *this; | |||
| } | |||
| /// Setter method. | |||
| /// \return Builder setter method returns reference to the builder. | |||
| Builder &SetOpConnectorSize(int32_t connector_size) { | |||
| build_op_connector_size_ = connector_size; | |||
| return *this; | |||
| } | |||
| /// Setter method. | |||
| /// \return Builder setter method returns reference to the builder. | |||
| Builder &SetClient(std::shared_ptr<CacheClient> cache_client) { | |||
| build_cache_client_ = cache_client; | |||
| return *this; | |||
| } | |||
| /// \brief Setter method. | |||
| /// \return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||
| build_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| /// \brief The builder "build" method creates the final object and does some init on it. | |||
| /// \param ptr The shared_ptr to the new CacheLookupOp object | |||
| /// \return Status | |||
| Status Build(std::shared_ptr<CacheLookupOp> *ptr); | |||
| private: | |||
| int32_t build_num_workers_; | |||
| int32_t rows_per_buffer_; | |||
| int32_t build_op_connector_size_; | |||
| std::shared_ptr<CacheClient> build_cache_client_; | |||
| std::shared_ptr<Sampler> build_sampler_; | |||
| // Check if the required parameters are set by the builder. | |||
| // \return Status The error code return | |||
| Status SanityCheck() const; | |||
| }; | |||
| /// \brief Constructor | |||
| /// \note It takes the same arguments as the base class. | |||
| /// \see CacheBase | |||
| CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | |||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler) | |||
| : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), Sampler(*(sampler.get())) {} | |||
| ~CacheLookupOp() = default; | |||
| // As a parallel op, we override these two functions | |||
| Status operator()() override; | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| // As a sampler, we override the following functions | |||
| Status ResetSampler() override; | |||
| Status HandshakeRandomAccessOp(const RandomAccessOp *op) override; | |||
| Status InitSampler() override; | |||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| bool AllowCacheMiss() override { return true; } | |||
| std::string Name() const override { return "CacheLookupOp"; } | |||
| /// \brief Base-class override for NodePass visitor acceptor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| protected: | |||
| Status ComputeColMap() override; | |||
| private: | |||
| WaitPost leaf_op_wp_; | |||
| Status RegisterResources() override; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ | |||
| @@ -0,0 +1,301 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <algorithm> | |||
| #include <functional> | |||
| #include <iomanip> | |||
| #include "dataset/core/config_manager.h" | |||
| #include "dataset/core/constants.h" | |||
| #include "dataset/core/global_context.h" | |||
| #include "dataset/engine/datasetops/cache_merge_op.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/util/task_manager.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| CacheMergeOp::~CacheMergeOp() = default; | |||
| // Always show the id and name as the first line, regardless of whether this is a summary or detailed print | |||
| void CacheMergeOp::Print(std::ostream &out, bool show_all) const { | |||
| out << "(" << std::setw(2) << operator_id_ << ") <CacheMergeOp>:"; | |||
| if (!show_all) { | |||
| // Call the super class for displaying any common 1-liner info | |||
| ParallelOp::Print(out, show_all); | |||
| // Then show any custom derived-internal 1-liner info for this op | |||
| out << "\n"; | |||
| } else { | |||
| // Call the super class for displaying any common detailed info | |||
| ParallelOp::Print(out, show_all); | |||
| // Then show any custom derived-internal stuff | |||
| out << "\n\n"; | |||
| } | |||
| } | |||
| CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, | |||
| std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler) | |||
| : ParallelOp(numWorkers, opConnectorSize, sampler), num_cleaners_(numCleaners), cache_client_(cache_client) {} | |||
| Status CacheMergeOp::operator()() { | |||
| // A queue of row ids for the cleaners to flush cache-miss rows to the cache server. | |||
| // We don't want a small queue, as that would block the parallel op workers. | |||
| // A row id is an 8-byte integer, so even a 512-entry queue costs only about 4KB. | |||
| io_que_ = std::make_unique<Queue<row_id_type>>(512); | |||
| RETURN_IF_NOT_OK(io_que_->Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK( | |||
| tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::WorkerEntry, this, std::placeholders::_1))); | |||
| RETURN_IF_NOT_OK( | |||
| tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::CacheMissWorkerEntry, this, std::placeholders::_1))); | |||
| // Dedicated cleaner threads (num_cleaners_ of them) to move TensorRows from the pool to the cache server | |||
| for (auto i = 0; i < num_cleaners_; ++i) { | |||
| RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Cleaner", std::bind(&CacheMergeOp::Cleaner, this))); | |||
| } | |||
| TaskManager::FindMe()->Post(); | |||
| return Status::OK(); | |||
| } | |||
| // Each parallel worker will pop from the CacheHit stream. If there is a missing TensorRow, we will wait | |||
| // until it shows up in the pool. | |||
| Status CacheMergeOp::WorkerEntry(int32_t worker_id) { | |||
| TaskManager::FindMe()->Post(); | |||
| std::shared_ptr<DatasetOp> cache_hit_stream = child_[kCacheHitChildIdx]; | |||
| std::unique_ptr<DataBuffer> db_ptr; | |||
| RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); | |||
| while (!db_ptr->eof()) { | |||
| if (db_ptr->eoe()) { | |||
| RETURN_IF_NOT_OK(EoeReceived(worker_id)); | |||
| db_ptr.reset(); | |||
| RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); | |||
| } else { | |||
| // See if there is any missing row | |||
| auto tbl = std::make_unique<TensorQTable>(); | |||
| while (db_ptr->NumRows() > 0) { | |||
| TensorRow row; | |||
| RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); | |||
| if (row.empty()) { | |||
| auto row_id = row.getId(); | |||
| TensorRowRequest *rq = nullptr; | |||
| RETURN_IF_NOT_OK(GetRq(row_id, &rq)); | |||
| // Block until the row shows up in the pool. | |||
| RETURN_IF_NOT_OK(rq->Wait(&row)); | |||
| } | |||
| tbl->push_back(std::move(row)); | |||
| } | |||
| db_ptr->set_tensor_table(std::move(tbl)); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr))); | |||
| RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr))); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { | |||
| TaskManager::FindMe()->Post(); | |||
| // We simply pop TensorRows from the stream, insert them into the pool, and | |||
| // wake up any worker that is waiting on a missing TensorRow. | |||
| // If we see an eoe, ignore it. For eof, we exit. | |||
| std::shared_ptr<DatasetOp> cache_missing_stream = child_[kCacheMissChildIdx]; | |||
| // Before we start, cache the schema at the server. Pick one of the workers to | |||
| // do it. The schema should have been computed at prepare time. | |||
| if (workerId == 0) { | |||
| RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map())); | |||
| } | |||
| std::unique_ptr<DataBuffer> db_ptr; | |||
| RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId)); | |||
| while (!db_ptr->eof()) { | |||
| if (db_ptr->eoe()) { | |||
| // Ignore it. | |||
| MS_LOG(DEBUG) << "Ignore eoe"; | |||
| } else { | |||
| while (db_ptr->NumRows() > 0) { | |||
| TensorRow row; | |||
| RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); | |||
| row_id_type row_id = row.getId(); | |||
| if (row_id < 0) { | |||
| std::string errMsg = "Expect non-negative row id: " + std::to_string(row_id); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| TensorRowRequest *rq = nullptr; | |||
| RETURN_IF_NOT_OK(GetRq(row_id, &rq)); | |||
| rq->WakeUpAny(std::move(row)); | |||
| // Let a cleaner flush this row out (asynchronously) to the cache server. | |||
| RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheMergeOp::Cleaner() { | |||
| TaskManager::FindMe()->Post(); | |||
| while (true) { | |||
| row_id_type row_id; | |||
| RETURN_IF_NOT_OK(io_que_->PopFront(&row_id)); | |||
| if (row_id < 0) { | |||
| break; | |||
| } | |||
| TensorRowRequest *rq = nullptr; | |||
| RETURN_IF_NOT_OK(GetRq(row_id, &rq)); | |||
| if (rq->GetState() == TensorRowRequest::State::kClean) { | |||
| // If already flushed, move on to the next one. | |||
| continue; | |||
| } | |||
| TensorRow row; | |||
| RETURN_IF_NOT_OK(rq->Release(&row)); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!row.empty(), "Programming error"); | |||
| Status rc = cache_client_->WriteRow(row); | |||
| // Bad rc should not bring down the pipeline | |||
| if (rc.IsError()) { | |||
| MS_LOG(WARNING) << "Cache write not successful: " << rc.ToString(); | |||
| } | |||
| rq->SetState(TensorRowRequest::State::kClean); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowRequest **out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| std::unique_lock<std::mutex> lck(mux_); | |||
| auto it = cache_miss_map_.find(row_id); | |||
| if (it != cache_miss_map_.end()) { | |||
| *out = it->second.GetMutablePointer(); | |||
| } else { | |||
| // We will create a new one. | |||
| auto alloc = Services::GetAllocator<TensorRowRequest>(); | |||
| auto r = cache_miss_map_.emplace(row_id, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>(alloc)); | |||
| if (r.second) { | |||
| auto &mem = r.first->second; | |||
| RETURN_IF_NOT_OK(mem.allocate(1, row_id)); | |||
| *out = mem.GetMutablePointer(); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Map insert fail."); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Run any common code from the super class first before adding our own specific logic | |||
| Status CacheMergeOp::PrepareNodePostAction() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children"); | |||
| RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); | |||
| // Get the computed check sum from all ops under the cache miss branch | |||
| uint32_t cache_crc = DatasetOp::GenerateCRC(child_[kCacheMissChildIdx]); | |||
| // This is a mappable cache op, so the leaf below already supplies row ids and the | |||
| // server must not generate them. Construct the cache. | |||
| const bool generate_ids = false; | |||
| Status rc = cache_client_->CreateCache(cache_crc, generate_ids); | |||
| if (rc.get_code() == StatusCode::kDuplicateKey) { | |||
| // We are told the cache has been created already. | |||
| MS_LOG(INFO) << "Cache created already"; | |||
| rc = Status::OK(); | |||
| } | |||
| RETURN_IF_NOT_OK(rc); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheMergeOp::ComputeColMap() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(child_[kCacheMissChildIdx] != nullptr, "Cache miss stream empty"); | |||
| if (column_name_id_map().empty()) { | |||
| column_name_id_map_ = child_[kCacheMissChildIdx]->column_name_id_map(); | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map().empty(), "No column map detected"); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheMergeOp::TensorRowRequest::Wait(TensorRow *out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| // Block until the missing row is in the pool. | |||
| RETURN_IF_NOT_OK(use_count_.P()); | |||
| std::unique_lock<std::mutex> lck(dq_mux_); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error"); | |||
| *out = std::move(row_.front()); | |||
| row_.pop_front(); | |||
| return Status::OK(); | |||
| } | |||
| void CacheMergeOp::TensorRowRequest::WakeUpAny(TensorRow &&row) { | |||
| std::unique_lock<std::mutex> lck(dq_mux_); | |||
| // Technically, the number of times this row shows up in the cache miss stream equals | |||
| // the number of P() calls. However the cleaner needs the row too, so we keep an extra copy. | |||
| if (GetState() == State::kEmpty) { | |||
| // We will do a deep copy | |||
| for (auto &ts : row) { | |||
| auto out_ts = std::make_shared<Tensor>(ts->shape(), ts->type(), ts->GetBuffer(), ts->SizeInBytes()); | |||
| cleaner_copy_.push_back(out_ts); | |||
| } | |||
| cleaner_copy_.setId(row.getId()); | |||
| // Change the state to dirty | |||
| SetState(State::kDirty); | |||
| } | |||
| row_.push_back(std::move(row)); | |||
| // Bump up the use count by 1. This wakes up any parallel worker that is waiting | |||
| // for this row. | |||
| use_count_.V(); | |||
| } | |||
| Status CacheMergeOp::TensorRowRequest::Release(TensorRow *out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| // We are not holding any mutex here because the cleaner isn't really touching the deque row_. | |||
| // In case we have multiple cleaners and they all see the copy, the compare-exchange below | |||
| // guarantees that only one of them will get it. | |||
| auto expected = State::kDirty; | |||
| if (st_.compare_exchange_strong(expected, State::kClean)) { | |||
| *out = std::move(cleaner_copy_); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Builder constructor. Creates the builder object. | |||
| CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| build_num_workers_ = cfg->num_parallel_workers(); | |||
| build_op_connector_size_ = cfg->op_connector_size(); | |||
| build_num_cleaners_ = 1; | |||
| } | |||
| // Check if the required parameters are set by the builder. | |||
| Status CacheMergeOp::Builder::SanityCheck() const { | |||
| if (build_cache_client_ == nullptr) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheMergeOp requires a CacheClient"); | |||
| } | |||
| // Make sure the cache client has a valid session | |||
| if (!build_cache_client_->session_id()) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, | |||
| "Cache client for CacheMergeOp is missing session id"); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // The builder "build" method creates the final object and does some init on it | |||
| Status CacheMergeOp::Builder::Build(std::shared_ptr<CacheMergeOp> *ptr) { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| *ptr = std::make_shared<CacheMergeOp>(build_num_workers_, build_op_connector_size_, build_num_cleaners_, | |||
| build_cache_client_, build_sampler_); | |||
| return Status::OK(); | |||
| } | |||
| // Pre-Visitor accept method for NodePass | |||
| Status CacheMergeOp::PreAccept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call the pre-visitation | |||
| return p->PreRunOnNode(shared_from_base<CacheMergeOp>(), modified); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status CacheMergeOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(shared_from_base<CacheMergeOp>(), modified); | |||
| } | |||
| Status CacheMergeOp::EoeReceived(int32_t worker_id) { | |||
| // If we are in a repeat path, send the eoe up. | |||
| // Otherwise ignore it. | |||
| if (BitTest(op_ctrl_flags_, kDeOpRepeated)) { | |||
| return DatasetOp::EoeReceived(worker_id); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,196 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ | |||
| #define DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ | |||
| #include <atomic> | |||
| #include <deque> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "dataset/core/tensor_row.h" | |||
| #include "dataset/engine/cache/cache_client.h" | |||
| #include "dataset/engine/datasetops/parallel_op.h" | |||
| #include "dataset/engine/dataset_iterator.h" | |||
| #include "dataset/util/queue.h" | |||
| #include "dataset/util/semaphore.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief Provides a method to merge two streams (one from the CacheLookupOp and one from the cache miss | |||
| /// stream) into a single stream | |||
| class CacheMergeOp : public ParallelOp { | |||
| public: | |||
| // Some handshake structures among the main thread, cleaner threads and parallel op threads. | |||
| class TensorRowRequest { | |||
| public: | |||
| enum class State : uint8_t { | |||
| kEmpty = 0, // No row in the deque | |||
| kDirty = 1, // Cleaner hasn't flushed it to the cache server yet. | |||
| kClean = 2 // The row has been flushed already. | |||
| }; | |||
| explicit TensorRowRequest(row_id_type id) : st_(State::kEmpty), use_count_(0) {} | |||
| ~TensorRowRequest() = default; | |||
| State GetState() const { return st_; } | |||
| void SetState(State newState) { st_ = newState; } | |||
| Status Wait(TensorRow *out); | |||
| void WakeUpAny(TensorRow &&row); | |||
| Status Release(TensorRow *out); | |||
| private: | |||
| std::mutex dq_mux_; | |||
| std::atomic<State> st_; | |||
| Semaphore use_count_; | |||
| std::deque<TensorRow> row_; | |||
| TensorRow cleaner_copy_; | |||
| }; | |||
| constexpr static int kCacheHitChildIdx = 0; // Cache hit stream | |||
| constexpr static int kCacheMissChildIdx = 1; // Cache miss stream | |||
| /// \brief The nested builder class inside of the CacheMergeOp is used to help manage all of | |||
| /// the arguments for constructing it. Use the builder by setting each argument | |||
| /// with the provided set methods, and then finally call the build method to execute | |||
| /// the actual construction. | |||
| class Builder { | |||
| public: | |||
| /// Builder constructor. Creates the builder object. | |||
| /// \note No default args | |||
| Builder(); | |||
| /// Default destructor | |||
| ~Builder() = default; | |||
| /// Setter method. | |||
| /// \return Builder setter method returns reference to the builder. | |||
| Builder &SetNumWorkers(int32_t num_workers) { | |||
| build_num_workers_ = num_workers; | |||
| return *this; | |||
| } | |||
| /// Setter method. | |||
| /// \return Builder setter method returns reference to the builder. | |||
| Builder &SetOpConnectorSize(int32_t connector_size) { | |||
| build_op_connector_size_ = connector_size; | |||
| return *this; | |||
| } | |||
| /// Setter method. | |||
| /// \return Builder setter method returns reference to the builder. | |||
| Builder &SetClient(std::shared_ptr<CacheClient> cache_client) { | |||
| build_cache_client_ = cache_client; | |||
| return *this; | |||
| } | |||
| /// \brief Setter method | |||
| /// \param sampler | |||
| /// \return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||
| build_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| /// \brief Setter method | |||
| /// \param num_cleaners | |||
| /// \return Builder setter method returns reference to the builder. | |||
| Builder &SetNumCleaner(int32_t num_cleaners) { | |||
| build_num_cleaners_ = num_cleaners; | |||
| return *this; | |||
| } | |||
| /// The builder "build" method creates the final object and does some init on it. | |||
| /// \param ptr The shared_ptr to the new CacheMergeOp object | |||
| /// \return Status | |||
| Status Build(std::shared_ptr<CacheMergeOp> *ptr); | |||
| private: | |||
| int32_t build_num_workers_; | |||
| int32_t build_op_connector_size_; | |||
| int32_t build_num_cleaners_; | |||
| std::shared_ptr<CacheClient> build_cache_client_; | |||
| std::shared_ptr<Sampler> build_sampler_; | |||
| /// Check if the required parameters are set by the builder. | |||
| /// \return Status The error code return | |||
| Status SanityCheck() const; | |||
| }; | |||
| /// \brief Constructor | |||
| /// \param numWorkers Number of parallel workers as a derived class of ParallelOp | |||
| /// \param opConnectorSize Connector size as a derived class of ParallelOp | |||
| /// \param numCleaners Number of cleaners to move cache miss rows into the cache server | |||
| /// \param cache_client CacheClient to communicate with the Cache server | |||
| /// \param sampler Sampler, passed through as a derived class of ParallelOp | |||
| CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, | |||
| std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler); | |||
| ~CacheMergeOp(); | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| friend std::ostream &operator<<(std::ostream &out, const CacheMergeOp &mo) { | |||
| mo.Print(out, false); | |||
| return out; | |||
| } | |||
| /// \brief Master thread responsible for spawning all the necessary worker threads for the two streams and | |||
| /// the threads for the cleaners. | |||
| /// \return Status object | |||
| Status operator()() override; | |||
| /// \brief Entry function for the worker threads that fetch rows from the CacheLookupOp | |||
| /// \param workerId | |||
| /// \return Status object | |||
| Status WorkerEntry(int32_t workerId) override; | |||
| Status PrepareNodePostAction() override; | |||
| /// \brief Entry function for the worker threads that fetch rows from the cache miss stream | |||
| /// \param workerId | |||
| /// \return Status object | |||
| Status CacheMissWorkerEntry(int32_t workerId); | |||
| Status GetRq(row_id_type row_id, TensorRowRequest **); | |||
| /// \brief Base-class override for NodePass pre-visit acceptor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status PreAccept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for NodePass visitor acceptor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for eoe handling | |||
| /// \param worker_id | |||
| /// \return Status object | |||
| Status EoeReceived(int32_t worker_id) override; | |||
| protected: | |||
| Status ComputeColMap() override; | |||
| private: | |||
| std::mutex mux_; | |||
| std::map<row_id_type, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>> cache_miss_map_; | |||
| std::unique_ptr<Queue<row_id_type>> io_que_; | |||
| std::shared_ptr<CacheClient> cache_client_; | |||
| int32_t num_cleaners_; | |||
| /// \brief These are the entry functions for the cleaner threads. Each cleaner is responsible for | |||
| /// moving cache miss TensorRow into the CacheServer. | |||
| /// \return Status object | |||
| Status Cleaner(); | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ | |||
| @@ -0,0 +1,219 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/engine/datasetops/cache_op.h" | |||
| #include <chrono> | |||
| #include <memory> | |||
| #include <thread> | |||
| #include <vector> | |||
| #include "dataset/core/config_manager.h" | |||
| #include "dataset/core/constants.h" | |||
| #include "dataset/core/global_context.h" | |||
| #include "dataset/engine/datasetops/repeat_op.h" | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "dataset/util/task_manager.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Builder constructor. Creates the builder object. | |||
| CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| build_num_workers_ = cfg->num_parallel_workers(); | |||
| rows_per_buffer_ = cfg->rows_per_buffer(); | |||
| build_op_connector_size_ = cfg->op_connector_size(); | |||
| } | |||
| // Check if the required parameters are set by the builder. | |||
| Status CacheOp::Builder::SanityCheck() const { | |||
| if (build_cache_client_ == nullptr) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheOp requires a CacheClient"); | |||
| } | |||
| // Make sure the cache client has a valid session | |||
| if (!build_cache_client_->session_id()) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cache client for CacheOp is missing session id"); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // The builder "build" method creates the final object and does some init on it | |||
| Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| *ptr = std::make_shared<CacheOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_, build_cache_client_, | |||
| build_sampler_); | |||
| RETURN_IF_NOT_OK((*ptr)->InitCache()); | |||
| return Status::OK(); | |||
| } | |||
| // Constructor of CacheOp | |||
| CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | |||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler) | |||
| : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), | |||
| num_guys_in_(0), | |||
| phase_(Phase::kBuildPhase) {} | |||
| // Destructor | |||
| CacheOp::~CacheOp() = default; | |||
| // Private function for cache setup/init work just after construction | |||
| Status CacheOp::InitCache() { return Status::OK(); } | |||
| // This class functor will provide the master loop that drives the logic for performing the work | |||
| Status CacheOp::operator()() { | |||
| if (!sampler_) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, | |||
| "CacheOp requires a sampler before it can be executed!"); | |||
| } | |||
| RETURN_IF_NOT_OK(RegisterResources()); | |||
| // Kick off the workers | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheOp::WorkerEntry, this, std::placeholders::_1))); | |||
| // required task group sync after launching workers | |||
| TaskManager::FindMe()->Post(); | |||
| // Wait for the workers to finish caching the rows. | |||
| RETURN_IF_NOT_OK(WaitForCachingAllRows()); | |||
| RETURN_IF_NOT_OK(FetchSamplesToWorkers()); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheOp::CacheAllRows(int32_t worker_id) { | |||
| // If the current phase is to fill the cache, do it then. | |||
| if (phase_ == Phase::kBuildPhase) { | |||
| // We will take the chance to cache the schema at the server. | |||
| // Just do it once and pick one worker to do it. | |||
| if (worker_id == 0) { | |||
| RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map())); | |||
| } | |||
| MS_LOG(INFO) << "CacheOp first epoch SAVE mode started. Worker: " << worker_id; | |||
| // SAVE mode loop | |||
| std::unique_ptr<DataBuffer> db_ptr; | |||
| RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0)); | |||
| while (!db_ptr->eof()) { | |||
| if (!db_ptr->eoe()) { | |||
| RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr))); | |||
| } else { | |||
| // In a repeat-over-cache scenario, any of the "real" leaf operators below us have been set up | |||
| // as non-repeating leaf ops. As such, they only do one epoch and then quit. Since we got | |||
| // the eoe to indicate the end of the epoch, we should next expect to get the eof. | |||
| // Drain this eof so that we don't leave it sitting there on a connector that we'll never fetch | |||
| // from again. | |||
| RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0)); | |||
| if (!db_ptr->eof()) { | |||
| RETURN_STATUS_UNEXPECTED("Cache op expects to get an eof after eoe from child."); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0)); | |||
| } | |||
| } | |||
| // Let the main guy know we are done. | |||
| auto last_guy_in = num_guys_in_.fetch_add(1); | |||
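| // Worked example (illustration): with num_workers_ == 3 the workers see fetch_add() | |||
| // return 0, 1 and 2; only the last one in (the one that sees 2) sets the wait post, | |||
| // while the other two block on it below. | |||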
| if ((last_guy_in + 1) == num_workers_) { | |||
| rows_cache_done_.Set(); | |||
| } else { | |||
| // Let's do a sync up here. | |||
| RETURN_IF_NOT_OK(rows_cache_done_.Wait()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheOp::WaitForCachingAllRows() { | |||
| // Wait for the workers to finish caching the rows. | |||
| RETURN_IF_NOT_OK(rows_cache_done_.Wait()); | |||
| // Move from build phase to fetch phase if we are the one to fill the cache | |||
| if (phase_ == Phase::kBuildPhase) { | |||
| RETURN_IF_NOT_OK(cache_client_->BuildPhaseDone()); | |||
| // Move to the next phase | |||
| phase_ = Phase::kFetchPhase; | |||
| } | |||
| // Get statistics from the server, and if we are not the one to create the cache, | |||
| // wait until the state changes from the build phase to the fetch phase. | |||
| CacheClient::ServiceStat stat{}; | |||
| bool build_phase_done = true; | |||
| do { | |||
| RETURN_IF_NOT_OK(cache_client_->GetStat(&stat)); | |||
| build_phase_done = stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase); | |||
| if (!build_phase_done) { | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(100)); | |||
| } | |||
| } while (!build_phase_done); | |||
| const row_id_type min_key = stat.min_row_id; | |||
| const row_id_type max_key = stat.max_row_id; | |||
| num_rows_ = max_key - min_key + 1; | |||
| MS_LOG(INFO) << "Number of rows cached: " << num_rows_; | |||
| MS_LOG(INFO) << "Number of rows cached in memory: " << stat.num_mem_cached; | |||
| MS_LOG(INFO) << "Number of rows spilled to disk: " << stat.num_disk_cached; | |||
| // Now all rows are cached and we have done a sync point check. The next phase is to | |||
| // pick up fetch requests from the sampler and pass the rows up to the caller. | |||
| RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheOp::WorkerEntry(int32_t worker_id) { | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(CacheAllRows(worker_id)); | |||
| RETURN_IF_NOT_OK(FetchFromCache(worker_id)); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheOp::RegisterResources() { | |||
| RETURN_IF_NOT_OK(CacheBase::RegisterResources()); | |||
| RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(keys_miss_.Register(tree_->AllTasks())); | |||
| return Status::OK(); | |||
| } | |||
| // Base-class override for setting specific CacheOp configurations. This code will be called | |||
| // during the execution tree prepare phase BEFORE traversing down to child operators. | |||
| uint32_t CacheOp::PrepareFlags() const { return ExecutionTree::kDePrepCache; } | |||
| // Base-class override for special eoe handler. | |||
| // CacheOp must override this because it shall not perform default handling of eoe. Instead | |||
| // the CacheOp manages actions related to the end of the epoch. | |||
| Status CacheOp::EoeReceived(int32_t worker_id) { | |||
| state_ = OpState::kDeOpIdle; | |||
| return Status::OK(); | |||
| } | |||
| // Base-class override for handling cases when an eof is received. | |||
| Status CacheOp::EofReceived(int32_t worker_id) { | |||
| // EofReceived is overridden because we want to handle the eof manually. | |||
| // The default behaviour is to pack it and flow it up to the next connector, | |||
| // whereas here we want a no-op so the cache op can take the correct action itself. | |||
| return Status::OK(); | |||
| } | |||
| // Pre-Visitor accept method for NodePass | |||
| Status CacheOp::PreAccept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call the pre-visitation | |||
| return p->PreRunOnNode(shared_from_base<CacheOp>(), modified); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status CacheOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(shared_from_base<CacheOp>(), modified); | |||
| } | |||
| // A public wrapper for creating the cache through the client | |||
| Status CacheOp::CreateCache(uint32_t cache_crc) { | |||
| // This is a non-mappable cache op, so the row ids need to be generated. | |||
| // Construct the cache | |||
| const bool generate_ids = true; | |||
| Status rc = cache_client_->CreateCache(cache_crc, generate_ids); | |||
| if (rc.get_code() == StatusCode::kDuplicateKey) { | |||
| // We are told the cache has been created already. So we skip the build phase. | |||
| phase_ = Phase::kFetchPhase; | |||
| rc = Status::OK(); | |||
| } | |||
| RETURN_IF_NOT_OK(rc); | |||
| return Status::OK(); | |||
| } | |||
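| // Hedged usage sketch (illustration only, not part of this change): a cache transformation | |||
| // pass (see cache_pass.cc) is expected to build the CacheOp and then create the cache from | |||
| // the CRC of the subtree below it. The crc value and the client are assumed to come from | |||
| // the caller. | |||
| namespace { | |||
| Status BuildCacheOpSketch(std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler, | |||
| uint32_t cache_crc, std::shared_ptr<CacheOp> *out) { | |||
| CacheOp::Builder builder; | |||
| RETURN_IF_NOT_OK(builder.SetClient(std::move(cache_client)).SetSampler(std::move(sampler)).Build(out)); | |||
| // If the server reports the cache already exists, CreateCache moves straight to the fetch phase. | |||
| return (*out)->CreateCache(cache_crc); | |||
| } | |||
| } // namespace | |||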
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,168 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_DATASETOPS_CACHE_OP_H_ | |||
| #define DATASET_ENGINE_DATASETOPS_CACHE_OP_H_ | |||
| #include <atomic> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include "dataset/engine/datasetops/cache_base_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief CacheOp provides a memory/disk cache that acts as a save-point within a non-mappable dataset. | |||
| /// \note For a mappable dataset, please see CacheLookupOp. | |||
| /// \see CacheLookupOp | |||
| class CacheOp : public CacheBase, public RandomAccessOp { | |||
| public: | |||
| // This CacheOp is for the non-mappable case, where execution is divided into two phases. | |||
| // In the first phase we cache all the rows from the child (and let the cache server | |||
| // assign row ids). There is no read access in the first phase. Once the cache is fully | |||
| // built, we switch to the second phase and serve fetch requests from the sampler. | |||
| enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 }; | |||
| /// \brief The nested builder class inside of the CacheOp is used to help manage all of | |||
| /// the arguments for constructing it. Use the builder by setting each argument | |||
| /// with the provided set methods, and then finally call the build method to execute | |||
| /// the actual construction. | |||
| class Builder { | |||
| public: | |||
| /// \brief Builder constructor. Creates the builder object. | |||
| /// \note No default args | |||
| Builder(); | |||
| // Default destructor | |||
| ~Builder() = default; | |||
| /// \brief Setter method. | |||
| /// \return Builder setter method returns reference to the builder. | |||
| Builder &SetNumWorkers(int32_t num_workers) { | |||
| build_num_workers_ = num_workers; | |||
| return *this; | |||
| } | |||
| /// \brief Setter method. | |||
| /// \return Builder setter method returns reference to the builder. | |||
| Builder &SetOpConnectorSize(int32_t connector_size) { | |||
| build_op_connector_size_ = connector_size; | |||
| return *this; | |||
| } | |||
| /// Setter method. | |||
| /// \return Builder setter method returns reference to the builder. | |||
| Builder &SetClient(std::shared_ptr<CacheClient> cache_client) { | |||
| build_cache_client_ = cache_client; | |||
| return *this; | |||
| } | |||
| /// \brief Setter method | |||
| /// \param rows_per_buffer | |||
| /// \return Builder setter method returns reference to the builder. | |||
| Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { | |||
| rows_per_buffer_ = rows_per_buffer; | |||
| return *this; | |||
| } | |||
| /// \brief Setter method | |||
| /// \param sampler | |||
| /// \return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||
| build_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| /// \brief The builder "build" method creates the final object and does some init on it. | |||
| /// \param ptr The shared_ptr to the new CacheOp object | |||
| /// \return Status | |||
| Status Build(std::shared_ptr<CacheOp> *ptr); | |||
| private: | |||
| int32_t build_num_workers_; | |||
| int32_t rows_per_buffer_; | |||
| int32_t build_op_connector_size_; | |||
| std::shared_ptr<CacheClient> build_cache_client_; | |||
| std::shared_ptr<Sampler> build_sampler_; | |||
| /// \brief Check if the required parameters are set by the builder. | |||
| /// \return Status The error code return | |||
| Status SanityCheck() const; | |||
| }; | |||
| /// \brief Constructor of CacheOp | |||
| /// \note The builder class should be used to call it. | |||
| /// \param num_workers The number of worker threads. | |||
| /// \param op_connector_size The size of each queue in the connector. | |||
| /// \param rows_per_buf The number of rows in each buffer fetched from the cache. | |||
| /// \param cache_client The CacheClient used to communicate with the cache server. | |||
| /// \param sampler The sampler that drives the fetch phase. | |||
| CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | |||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler); | |||
| // Destructor | |||
| ~CacheOp(); | |||
| /// \brief Base-class override for setting specific CacheOp configurations. This code will be called | |||
| /// during the execution tree prepare phase BEFORE traversing down to child operators. | |||
| uint32_t PrepareFlags() const override; | |||
| /// \brief Base-class override for special eoe handler. | |||
| /// CacheOp must override this because it shall not perform default handling of eoe. Instead | |||
| /// the CacheOp manages actions related to the end of the epoch. | |||
| /// \return Status - The error code return | |||
| Status EoeReceived(int32_t worker_id) override; | |||
| /// \brief Base-class override for NodePass pre-visit acceptor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status PreAccept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for NodePass visitor acceptor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for handling cases when an eof is received. | |||
| /// \param worker_id - The worker id | |||
| /// \return Status - The error code return | |||
| Status EofReceived(int32_t worker_id) override; | |||
| Status operator()() override; | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| /// \brief Base-class override for handling cases if we allow cache miss | |||
| bool AllowCacheMiss() override { return false; } | |||
| /// \brief Base-class override for the name of this operator | |||
| std::string Name() const override { return "CacheOp"; } | |||
| /// \brief A public wrapper for creating the cache through the client | |||
| /// \param[in] cache_crc The crc that identifies the cache | |||
| /// \see cache_pass.cc | |||
| /// \return Status return code | |||
| Status CreateCache(uint32_t cache_crc); | |||
| private: | |||
| WaitPost rows_cache_done_; | |||
| std::atomic<int64_t> num_guys_in_; | |||
| Phase phase_; | |||
| /// \brief The main thread will wait until all the rows are cached and will start the handshake with the sampler. | |||
| /// \return Status object | |||
| Status WaitForCachingAllRows(); | |||
| /// \brief For non-mappable dataset, there is a build phase where we cache all the rows. | |||
| /// \return Status object | |||
| Status CacheAllRows(int32_t worker_id); | |||
| Status RegisterResources() override; | |||
| /// \brief Private function for cache setup/init work just after construction | |||
| /// \return Status The error code return | |||
| Status InitCache(); | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_DATASETOPS_CACHE_OP_H_ | |||
| @@ -61,46 +61,39 @@ void ConcatOp::Print(std::ostream &out, bool show_all) const { | |||
| Status ConcatOp::operator()() { | |||
| // children_num_ must be computed here, since the child list is only complete once the tree is built | |||
| children_num_ = static_cast<int32_t>(child_.size()); | |||
| TaskManager::FindMe()->Post(); | |||
| std::unique_ptr<DataBuffer> buf; | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); | |||
| int eof_count = 0; | |||
| while (eof_count != children_num_) { | |||
| while (eof_count == 0) { | |||
| for (int i = 0; i < children_num_; i++) { | |||
| // 1. Throw the eof buffer when meet it | |||
| if (buf->eof() || buf->eoe()) { | |||
| RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); | |||
| // 1. Read the first buffer | |||
| RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); | |||
| if (buf->eof()) { | |||
| eof_count++; | |||
| continue; | |||
| } | |||
| // 2. Do verification as for column name, column data type and rank of column data | |||
| RETURN_IF_NOT_OK(Verify(i, buf)); | |||
| if (!buf->eoe()) { | |||
| RETURN_IF_NOT_OK(Verify(i, buf)); | |||
| } | |||
| // 3. Put the data into output_connector | |||
| while (!buf->eoe() && !buf->eof()) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf))); | |||
| RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); | |||
| } | |||
| // 4. Throw the eoe buffer when meet it | |||
| if (buf->eoe() && (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat))) { | |||
| RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); | |||
| } | |||
| // 5. Add eoe buffer after get buffer from all child | |||
| if (i == (children_num_ - 1)) { | |||
| auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | |||
| } | |||
| if (buf->eof()) { | |||
| eof_count++; | |||
| } | |||
| } | |||
| // 4. Add eoe buffer after get buffer from all child | |||
| if (eof_count == 0) { | |||
| auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | |||
| } | |||
| } | |||
| // 6. Add eof buffer in the end manually | |||
| CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_, | |||
| "Something went wrong, eof count does not match the number of children."); | |||
| // 5. Add eof buffer in the end manually | |||
| MS_LOG(DEBUG) << "Add the eof buffer manually in the end."; | |||
| auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); | |||
| return Status::OK(); | |||
| } | |||
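| // Worked example of the flow above (illustration): with two children that each produce one | |||
| // epoch, say child 0 -> d0 d1 eoe eof and child 1 -> d2 eoe eof, the first pass forwards | |||
| // d0 d1 d2 and emits a single eoe; the second pass sees both eofs, exits the loop, passes | |||
| // the sanity check, and appends the final eof. Net output: d0 d1 d2 eoe eof. | |||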
| @@ -126,12 +119,6 @@ Status ConcatOp::Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf) { | |||
| return Status::OK(); | |||
| } | |||
| Status ConcatOp::PrepareNodePostAction() { | |||
| RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); | |||
| tree_->AddToEOEOpStack(shared_from_this()); | |||
| return Status::OK(); | |||
| } | |||
| // We need to overwrite the super class ComputeColMap here because the number of children is more than 1. | |||
| Status ConcatOp::ComputeColMap() { | |||
| if (column_name_id_map_.empty()) { | |||
| @@ -75,12 +75,6 @@ class ConcatOp : public PipelineOp { | |||
| // @return Status - The error code return | |||
| Status operator()() override; | |||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||
| // their role. | |||
| // @notes Derived versions of this function should always call its superclass version first | |||
| // before providing their own implementations. | |||
| Status PrepareNodePostAction() override; | |||
| // Op name getter | |||
| // @return Name of the current Op | |||
| std::string Name() const override { return "ConcatOp"; } | |||
| @@ -153,16 +153,38 @@ Status DatasetOp::Remove() { | |||
| } | |||
| } | |||
| // Finally, clear "this" op's parent and child pointers since we have just | |||
| // disconnected it from the tree and invalidated its fields. | |||
| child_.clear(); | |||
| parent_.clear(); | |||
| operator_id_ = kInvalidOperatorId; | |||
| tree_ = nullptr; | |||
| return Status::OK(); | |||
| } | |||
| // Getter function to get a shared pointer to our childAdds a operator to become our child. | |||
| // Getter function to get a shared pointer to our child | |||
| std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const { | |||
| std::shared_ptr<DatasetOp> return_op = nullptr; | |||
| if (child_.empty()) { | |||
| return return_op; | |||
| } | |||
| MS_ASSERT(child_index < static_cast<int>(child_.size())); | |||
| // Return a shared pointer | |||
| return child_[child_index]; | |||
| } | |||
| // Getter function to get the parent pointer | |||
| void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const { | |||
| if (parent_.empty()) { | |||
| // common case if this is a root node | |||
| *parent = nullptr; | |||
| } else { | |||
| MS_ASSERT(parent_index < static_cast<int>(parent_.size())); | |||
| *parent = parent_[parent_index]; | |||
| } | |||
| } | |||
| // Creates the connector within this operator | |||
| void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) { | |||
| MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers | |||
| @@ -264,19 +286,11 @@ Status DatasetOp::EofReceived(int32_t worker_id) { | |||
| // During tree prepare phase, operators may have specific pre-operations to perform depending on | |||
| // their role. | |||
| Status DatasetOp::PrepareNodePreAction() { | |||
| if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) set_control_flag(kDeOpRepeated); | |||
| return Status::OK(); | |||
| } | |||
| Status DatasetOp::PrepareNodePreAction() { return Status::OK(); } | |||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||
| // their role. | |||
| Status DatasetOp::PrepareNodePostAction() { | |||
| // If this op does not have any children and it is in a repeat path of the tree... | |||
| if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) { | |||
| // push ourselves onto the eoe operator stack. Later, a repeat/epoch ctrl operator | |||
| // above us will consume them. | |||
| tree_->AddToEOEOpStack(shared_from_this()); | |||
| } | |||
| // Creating Connector object for each op. | |||
| // The consumer of the root node is assumed to be one thread. | |||
| // If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion. | |||
| @@ -346,34 +360,13 @@ Status DatasetOp::Accept(NodePass *p, bool *modified) { | |||
| return p->RunOnNode(shared_from_this(), modified); | |||
| } | |||
| // A helper function with some common code that leaf nodes can use during | |||
| // prepare phase for checking if they need to assign a sampler to the cache. | |||
| Status DatasetOp::SaveSamplerForCache(bool random_access_op) { | |||
| // If we are a descendant under a cache op and we have a sampler, then save this sampler | |||
| // to a stack so that the cache can pick it up during it's processing above us. | |||
| if (sampler_) { | |||
| if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) { | |||
| // use move semantic to set our sampler_ to null after the move. This is okay because a sampler is | |||
| // useless to a random data op. It was only being used as a temporary holding until the cache can | |||
| // be created | |||
| tree_->AddToSamplerStack(sampler_); | |||
| MS_LOG(INFO) << "Preparing a leaf op: passing sampler up the tree for Cache handling."; | |||
| } else if (!random_access_op) { | |||
| // A sampler exists, but we are not in a caching tree and we are not a random access mappable leaf. | |||
| // This is an error because that type of leaf does not use sampling unless there's a cache to hook it into. | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "Non-mappable leaf op has a sampler, but it only supports sampling if there is a cache after it in the tree"); | |||
| } | |||
| } | |||
| if (!random_access_op) { | |||
| // Since we don't truly need the sampler for this non-mappable dataset and it's been saved for the cache | |||
| // we can remove it now from the base. | |||
| sampler_.reset(); | |||
| } | |||
| // Getter for the sampler, and it also removes the sampler from the op | |||
| Status DatasetOp::FetchRemoveSampler(std::shared_ptr<Sampler> *sampler) { | |||
| *sampler = sampler_; // It's okay if sampler_ points to nullptr | |||
| sampler_.reset(); // clear our member-copy of this pointer. We no longer have this sampler | |||
| return Status::OK(); | |||
| } | |||
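| // Hedged sketch (illustration only; the pass wiring and names are assumed): during the | |||
| // cache transformation a pass can lift the sampler off a non-mappable leaf and give it | |||
| // to the cache op, roughly: | |||
| //   std::shared_ptr<Sampler> leaf_sampler; | |||
| //   RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&leaf_sampler)); | |||
| //   cache_op->SetSampler(std::move(leaf_sampler));  // SetSampler is protected; assumed reachable | |||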
| uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) { | |||
| std::stringstream ss; | |||
| op->tree_->Print(ss, op); | |||
| @@ -45,10 +45,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| public: | |||
| static constexpr int32_t kInvalidOperatorId = -1; | |||
| // Flags that control operator runtime behaviours | |||
| // Operator control flags | |||
| enum OpControlFlags { | |||
| kDeOpNone = 0, | |||
| kDeOpRepeated = 1, // Operator is a leaf node in a repeat path | |||
| kDeOpRepeated = 1, // Operator is a node in a repeat path | |||
| kDeOpLastRepeat = 1 << 1 // We are in the last repeat loop | |||
| }; | |||
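| // e.g. (illustration) an op on the final pass of a repeat carries both flags: | |||
| //   op_ctrl_flags_ = kDeOpRepeated | kDeOpLastRepeat;  // individual bits checked via BitTest() | |||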
| @@ -71,17 +71,23 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| /// \param child - shared pointer to the child to remove. | |||
| Status RemoveChild(std::shared_ptr<DatasetOp> child); | |||
| /// \brief Removes this node from the tree and connects it's parent/child together. | |||
| /// \brief Removes this node from the tree and connects its parent/child together | |||
| /// \return Status error code returned | |||
| Status Remove(); | |||
| /// \brief Getter function to get a shared pointer to our child | |||
| /// \param child_index - An operator can have n children. Indicates choose which child to return. | |||
| /// \param[in] child_index An operator can have n children. Indicates which child to return. | |||
| /// \return The shared pointer to the child. If there are no children, it returns null regardless of the given index | |||
| std::shared_ptr<DatasetOp> child(int32_t child_index) const; | |||
| /// \brief Inserts a operator as the parent current op. | |||
| /// Inserted op will become the sole parent of the current op. | |||
| /// The existing parent of the current op will be transferred to the inserted op. | |||
| /// \brief Getter function to get the pointer to our parent | |||
| /// If there are no parents, it sets the output to nullptr regardless of the given index | |||
| /// \param[in] parent_index An operator can have n parents. Indicates which parent to return. | |||
| void Parent(DatasetOp **parent, int32_t parent_index) const; | |||
| // Inserts an operator as the parent of the current op. | |||
| // Inserted op will become the sole parent of the current op. | |||
| // The existing parent of the current op will be transferred to the inserted op. | |||
| Status InsertAsParent(std::shared_ptr<DatasetOp> to_add); | |||
| /// \brief Creates the connector within this operator | |||
| @@ -161,16 +167,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| /// \return Status - The error code return | |||
| virtual Status Reset(); | |||
| /// \brief This calls the reset function on this subtree in pre-order | |||
| /// \return Status - The error code return | |||
| virtual Status ResetSubtree() { | |||
| RETURN_IF_NOT_OK(Reset()); | |||
| for (const auto &c : child_) { | |||
| RETURN_IF_NOT_OK(c->ResetSubtree()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| /// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on | |||
| /// their role. | |||
| /// \notes Derived versions of this function should always call its superclass version first | |||
| @@ -296,7 +292,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| /// \return Shared pointer to the sampler (may return nullptr) | |||
| std::shared_ptr<Sampler> sampler() { return sampler_; } | |||
| /// Computes a CRC value for the operator | |||
| /// \brief Getter for the sampler, and it also removes the sampler from the op | |||
| /// \param[out] sampler A pointer to the output sampler that was removed | |||
| /// \return Status error code | |||
| Status FetchRemoveSampler(std::shared_ptr<Sampler> *sampler); | |||
| // Computes a CRC value for the operator | |||
| static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op); | |||
| /// \brief A helper templated function for casting "this" pointer to shared_ptr<derived> | |||
| @@ -307,17 +308,24 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| return std::static_pointer_cast<Derived>(shared_from_this()); | |||
| } | |||
| protected: | |||
| /// Adds a parent operator to this operator | |||
| /// \notes External callers do not have access to this function. | |||
| /// \param parent - The parent node to add | |||
| void AddParent(DatasetOp *parent); | |||
| /// \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one. | |||
| void SetSampler(std::shared_ptr<Sampler> sampler) { sampler_ = sampler; } | |||
| /// \brief Checks if this is a leaf node (0 children) | |||
| /// \return boolean returns true if it's a leaf | |||
| bool IsLeaf() { return (child_.empty()); } | |||
| /// Removes a parent operator from this operator | |||
| /// \notes External callers do not have access to this function. | |||
| /// \param parent - The parent node to remove | |||
| protected: | |||
| /// \brief Removes a parent operator from this operator | |||
| /// \notes External callers do not have access to this function | |||
| /// \param[in] parent The parent node to remove | |||
| void RemoveParent(const DatasetOp *parent); | |||
| /// \brief Adds a parent operator to this operator | |||
| /// \notes External callers do not have access to this function | |||
| /// \param[in] parent The parent node to add | |||
| void AddParent(DatasetOp *parent); | |||
| /// Compute the current op's column map using its child's column map. | |||
| /// Get called during the tree post-prepare phase in PrepareNodePostAction. | |||
| /// This base implementation just inherits the map from child 0, and can only be used if the number of children is 1. | |||
| @@ -325,12 +333,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| /// \return - Status | |||
| virtual Status ComputeColMap(); | |||
| /// A helper function with some common code that leaf nodes can use during | |||
| /// prepare phase for checking if they need to assign a sampler to the cache. | |||
| /// \param random_access_op - indicate if this is a mappable random access leaf or not | |||
| /// \return - Status | |||
| Status SaveSamplerForCache(bool random_access_op); | |||
| std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes | |||
| std::vector<DatasetOp *> parent_; // Parent nodes. No ownership | |||
| std::shared_ptr<Sampler> sampler_; // Some leaf ops might have a sampler | |||
| @@ -77,26 +77,6 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const { | |||
| } | |||
| } | |||
| // Base-class override for executing specific RepeatOp configurations. This code will be called | |||
| // during the execution tree prepare phase when it is visiting this operator. | |||
| Status RepeatOp::PrepareNodePostAction() { | |||
| // Run any common code from super class first before adding our own specific logic | |||
| RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); | |||
| std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromEOEOpStack(); | |||
| while (leaf_op != nullptr) { | |||
| // Track the leaf operators that are under this repeat op. | |||
| eoe_ops_.push_back(leaf_op); | |||
| leaf_op = tree_->PopFromEOEOpStack(); | |||
| } | |||
| // Push ourselves to the stack in case one of our ascendants is repeat too. | |||
| tree_->AddToEOEOpStack(shared_from_this()); | |||
| return Status::OK(); | |||
| } | |||
| // Base-class override for setting specific RepeatOp configurations. This code will be called | |||
| // during the execution tree prepare phase BEFORE traversing down to child operators. | |||
| uint32_t RepeatOp::PrepareFlags() const { return ExecutionTree::kDePrepRepeat; } | |||
| // This function returns the buffer that is at the top of our output connector. The caller is | |||
| // typically our parent node, when the parent is asking us to provide the next buffer of data. | |||
| // Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get | |||
| @@ -130,7 +110,8 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo | |||
| // Base-class override for handling cases when an eoe is received. | |||
| Status RepeatOp::EoeReceived(int32_t worker_id) { | |||
| repeat_count_++; | |||
| MS_LOG(DEBUG) << "Repeat operator end of epoch message received. Repeat count is now: " << repeat_count_ << "."; | |||
| MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ | |||
| << ") end of epoch message received. Repeat count is now: " << repeat_count_ << "."; | |||
| bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated); | |||
| bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat); | |||
| // If we've reached the requested repeat count, then flag the eoe nodes | |||
| @@ -149,8 +130,12 @@ Status RepeatOp::EoeReceived(int32_t worker_id) { | |||
| return Status::OK(); | |||
| } | |||
| // base-class ResetSubtree | |||
| return (DatasetOp::ResetSubtree()); | |||
| // Invoke a reset against the eoe nodes only. | |||
| for (auto &eoe_op : eoe_ops_) { | |||
| RETURN_IF_NOT_OK(eoe_op->Reset()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Class functor operator () override. | |||
| @@ -178,6 +163,18 @@ int32_t RepeatOp::num_consumers() const { | |||
| } | |||
| } | |||
| // Drive reset actions if needed | |||
| Status RepeatOp::Reset() { | |||
| // If there are nested repeats, an ascendant repeat may have us listed as an eoe op. | |||
| // In that case, we now have to bounce the reset down to our own eoe ops. | |||
| MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") reset."; | |||
| for (auto &eoe_op : eoe_ops_) { | |||
| RETURN_IF_NOT_OK(eoe_op->Reset()); | |||
| } | |||
| state_ = OpState::kDeOpRunning; | |||
| return Status::OK(); | |||
| } | |||
| int32_t RepeatOp::num_producers() const { | |||
| if (child_.empty() || child_[0] == nullptr) { | |||
| MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0."; | |||
| @@ -187,6 +184,12 @@ int32_t RepeatOp::num_producers() const { | |||
| } | |||
| } | |||
| // Pre-Visitor accept method for NodePass | |||
| Status RepeatOp::PreAccept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call the pre-visitation | |||
| return p->PreRunOnNode(shared_from_base<RepeatOp>(), modified); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status RepeatOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| @@ -18,6 +18,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "dataset/engine/datasetops/pipeline_op.h" | |||
| @@ -82,14 +83,6 @@ class RepeatOp : public PipelineOp { | |||
| // @return Status - The error code return | |||
| Status operator()() override; | |||
| // Base-class override for setting specific RepeatOp configurations. This code will be called | |||
| // during the execution tree prepare phase BEFORE traversing down to child operators. | |||
| uint32_t PrepareFlags() const override; | |||
| // Base-class override for executing specific RepeatOp configurations. This code will be called | |||
| // during the execution tree post-prepare phase when it is visiting this operator. | |||
| Status PrepareNodePostAction() override; | |||
| // This function returns the buffer that is at the top of our output connector. The caller is | |||
| // typically our parent node, when the parent is asking us to provide the next buffer of data. | |||
| // Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get | |||
| @@ -110,6 +103,10 @@ class RepeatOp : public PipelineOp { | |||
| // @param worker_id - The worker id | |||
| Status EofReceived(int32_t worker_id) override; | |||
| /// \brief Reset Op | |||
| /// \return Status - The error code return | |||
| Status Reset() override; | |||
| // Base-class override. Return the number of workers in the first parent. | |||
| int32_t num_consumers() const override; | |||
| @@ -118,16 +115,26 @@ class RepeatOp : public PipelineOp { | |||
| int32_t num_producers() const override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| /// \brief Base-class override for NodePass pre-visit acceptor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status PreAccept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for NodePass visitor acceptor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| // Op name getter | |||
| // @return Name of the current Op | |||
| std::string Name() const override { return "RepeatOp"; } | |||
| /// \brief Adds an operator to the repeat op's list of tracked leaf/eoe nodes | |||
| /// \param[in] eoe_op The input leaf/eoe operator to add to the list | |||
| void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); } | |||
| private: | |||
| int32_t max_repeats_; // The number of repeats that the user requested | |||
| int32_t repeat_count_; // A counter for the current number of executed repeats | |||
| @@ -22,6 +22,7 @@ | |||
| #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "dataset/engine/data_schema.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "dataset/kernels/image/image_utils.h" | |||
| namespace mindspore { | |||
| @@ -408,6 +409,12 @@ Status CelebAOp::Reset() { | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status CelebAOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(shared_from_base<CelebAOp>(), modified); | |||
| } | |||
| Status CelebAOp::ComputeColMap() { | |||
| // Set the column name map (base class field) | |||
| if (column_name_id_map_.empty()) { | |||
| @@ -169,6 +169,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||
| // @return Status - The error code return | |||
| Status AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer); | |||
| /// \brief Base-class override for NodePass visitor acceptor | |||
| /// \param[in] p Pointer to the NodePass to be accepted | |||
| /// \param[out] modified Indicator if the node was changed at all | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| // Op name getter | |||
| // @return Name of the current Op | |||
| std::string Name() const { return "CelebAOp"; } | |||
| @@ -26,6 +26,7 @@ | |||
| #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -450,6 +451,12 @@ Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t * | |||
| } | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status CifarOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(shared_from_base<CifarOp>(), modified); | |||
| } | |||
| Status CifarOp::ComputeColMap() { | |||
| // set the column name map (base class field) | |||
| if (column_name_id_map_.empty()) { | |||
| @@ -155,6 +155,12 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||
| // @return | |||
| static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count); | |||
| /// \brief Base-class override for NodePass visitor acceptor | |||
| /// \param[in] p Pointer to the NodePass to be accepted | |||
| /// \param[out] modified Indicator if the node was changed at all | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| // Op name getter | |||
| // @return Name of the current Op | |||
| std::string Name() const override { return "CifarOp"; } | |||
| @@ -24,6 +24,7 @@ | |||
| #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -624,6 +625,12 @@ Status CocoOp::GetClassIndexing(const std::string &dir, const std::string &file, | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status CocoOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(shared_from_base<CocoOp>(), modified); | |||
| } | |||
| Status CocoOp::ComputeColMap() { | |||
| // Set the column name map (base class field) | |||
| if (column_name_id_map_.empty()) { | |||
| @@ -200,6 +200,12 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||
| static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, | |||
| std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing); | |||
| /// \brief Base-class override for NodePass visitor acceptor | |||
| /// \param[in] p Pointer to the NodePass to be accepted | |||
| /// \param[out] modified Indicator if the node was changed at all | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| private: | |||
| // Initialize Sampler, calls sampler->Init() within | |||
| // @return Status - The error code return | |||
| @@ -26,6 +26,7 @@ | |||
| #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -416,6 +417,12 @@ Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dic | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status ManifestOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(shared_from_base<ManifestOp>(), modified); | |||
| } | |||
| Status ManifestOp::ComputeColMap() { | |||
| // Set the column name map (base class field) | |||
| if (column_name_id_map_.empty()) { | |||
| @@ -172,6 +172,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||
| static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage, | |||
| std::map<std::string, int32_t> *output_class_indexing); | |||
| /// \brief Base-class override for NodePass visitor acceptor | |||
| /// \param[in] p Pointer to the NodePass to be accepted | |||
| /// \param[out] modified Indicator if the node was changed at all | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| // Op name getter | |||
| // @return Name of the current Op | |||
| std::string Name() const override { return "ManifestOp"; } | |||
| @@ -23,6 +23,7 @@ | |||
| #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -428,6 +429,12 @@ Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) { | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status MnistOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(shared_from_base<MnistOp>(), modified); | |||
| } | |||
| Status MnistOp::ComputeColMap() { | |||
| // set the column name map (base class field) | |||
| if (column_name_id_map_.empty()) { | |||
| @@ -152,6 +152,12 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| // @return | |||
| static Status CountTotalRows(const std::string &dir, int64_t *count); | |||
| /// \brief Base-class override for NodePass visitor acceptor | |||
| /// \param[in] p Pointer to the NodePass to be accepted | |||
| /// \param[out] modified Indicator if the node was changed at all | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| // Op name getter | |||
| // @return Name of the current Op | |||
| std::string Name() const override { return "MnistOp"; } | |||
| @@ -22,6 +22,7 @@ | |||
| #include "dataset/util/random.h" | |||
| #include "dataset/util/wait_post.h" | |||
| #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -406,6 +407,12 @@ Status RandomDataOp::Reset() { | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status RandomDataOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(shared_from_base<RandomDataOp>(), modified); | |||
| } | |||
| Status RandomDataOp::ComputeColMap() { | |||
| // Extract the column name mapping from the schema and save it in the class. | |||
| if (column_name_id_map_.empty()) { | |||
| @@ -415,15 +422,5 @@ Status RandomDataOp::ComputeColMap() { | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||
| // their role. | |||
| Status RandomDataOp::PrepareNodePostAction() { | |||
| // Run common code from super class before adding RandomDataOp specific handling | |||
| RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); | |||
| // Specific handling for this op, we need to do cache op work to assign the sampler to the cache. | |||
| RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(false)); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -203,12 +203,6 @@ class RandomDataOp : public ParallelOp { | |||
| // @return Name of the current Op | |||
| std::string Name() const override { return "RandomDataOp"; } | |||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||
| // their role. | |||
| // @notes Derived versions of this function should always call its superclass version first | |||
| // before providing their own implementations. | |||
| Status PrepareNodePostAction() override; | |||
| private: | |||
| /** | |||
| * The entry point code for when workers are launched | |||
| @@ -266,6 +260,12 @@ class RandomDataOp : public ParallelOp { | |||
| return ++buffer_id_; | |||
| } | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| // Private function for computing the assignment of the column name map. | |||
| // @return - Status | |||
| Status ComputeColMap() override; | |||
| @@ -1019,31 +1019,28 @@ Status TFReaderOp::ComputeColMap() { | |||
| return Status::OK(); | |||
| } | |||
| // If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing | |||
| // a sampler for fetching the data. As such, any options in the tf reader need to be reset to their defaults so | |||
| // that this tf reader will produce the full set of data into the cache. | |||
| void TFReaderOp::MakeSimpleProducer() { | |||
| device_id_ = 0; | |||
| num_devices_ = 1; | |||
| total_rows_ = 0; | |||
| shuffle_files_ = false; | |||
| equal_rows_per_shard_ = false; | |||
| } | |||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||
| // their role. | |||
| Status TFReaderOp::PrepareNodePostAction() { | |||
| // Run common code from super class before adding TFReaderOp specific handling | |||
| RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); | |||
| // Specific handling for this op, we need to do cache op work to assign the sampler to the cache | |||
| // TF is a special case because it can support file-based sharding/shuffling, or, if there | |||
| // is a cache, then it can also do row-based sampler using the sampler on the cache. | |||
| // Thus, pass true for random access op flag when saving the sampler. This is a special case, | |||
| // since usually a non-mappable dataset would pass false here. | |||
| RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(true)); | |||
| // Now that the sampler has been saved for the cache, we need to adjust the TFReaderOp to turn it into | |||
| // a simpler producer of all data (no shuffling or sharding or anything) | |||
| if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) { | |||
| device_id_ = 0; | |||
| num_devices_ = 1; | |||
| total_rows_ = 0; | |||
| shuffle_files_ = false; | |||
| equal_rows_per_shard_ = false; | |||
| sampler_.reset(); // Normally SaveSampler code did this for us, but we passed in true above (See comment) | |||
| } else { | |||
| if (!BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) { | |||
| // This sanity check had been delayed until now in the prepare loop. | |||
| // If we are not in a cache path, then we can validate the the file-based sharding config. | |||
| // If we are not in a cache path, then we can validate the file-based sharding config. | |||
| // If we are in a cache path, there is no file-based sharding so the check is not correct in that | |||
| // situation. | |||
| if (!equal_rows_per_shard_ && dataset_files_list_.size() < static_cast<uint32_t>(num_devices_)) { | |||
| @@ -246,6 +246,11 @@ class TFReaderOp : public ParallelOp { | |||
| // @return Vector of the input file names | |||
| std::vector<std::string> FileNames() { return dataset_files_list_; } | |||
| /// \brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing | |||
| /// a sampler for fetching the data. As such, any options in the tf reader need to be reset to their defaults so | |||
| /// that this tf reader will produce the full set of data into the cache. | |||
| void MakeSimpleProducer(); | |||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||
| // their role. | |||
| // @notes Derived versions of this function should always call its superclass version first | |||
| @@ -25,6 +25,7 @@ | |||
| #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| using tinyxml2::XMLDocument; | |||
| using tinyxml2::XMLElement; | |||
| @@ -449,6 +450,11 @@ Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_t | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status VOCOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(shared_from_base<VOCOp>(), modified); | |||
| } | |||
| Status VOCOp::ComputeColMap() { | |||
| // Set the column name map (base class field) | |||
| @@ -205,6 +205,12 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, | |||
| const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing); | |||
| /// \brief Base-class override for NodePass visitor acceptor | |||
| /// \param[in] p Pointer to the NodePass to be accepted | |||
| /// \param[out] modified Indicator if the node was changed at all | |||
| /// \return Status of the node visit | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| // Op name getter | |||
| // @return Name of the current Op | |||
| std::string Name() const override { return "VOCOp"; } | |||
| @@ -127,12 +127,6 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D | |||
| return Status::OK(); | |||
| } | |||
| Status TakeOp::PrepareNodePostAction() { | |||
| RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); | |||
| tree_->AddToEOEOpStack(shared_from_this()); | |||
| return Status::OK(); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status TakeOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| @@ -78,12 +78,6 @@ class TakeOp : public PipelineOp { | |||
| // @return Status - The error code return | |||
| Status operator()() override; | |||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||
| // their role. | |||
| // @notes Derived versions of this function should always call its superclass version first | |||
| // before providing their own implementations. | |||
| Status PrepareNodePostAction() override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| @@ -21,6 +21,8 @@ | |||
| #include "dataset/util/task_manager.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "dataset/engine/opt/pre/removal_pass.h" | |||
| #include "dataset/engine/opt/pre/cache_transform_pass.h" | |||
| #include "dataset/engine/opt/post/repeat_pass.h" | |||
| #include "dataset/engine/perf/profiling.h" | |||
| #include "dataset/engine/perf/monitor.h" | |||
| @@ -215,18 +217,33 @@ Status ExecutionTree::PrepareTreePreAction() { | |||
| bool modified = false; | |||
| std::vector<std::unique_ptr<Pass>> pre_actions; | |||
| // Construct pre actions | |||
| MS_LOG(INFO) << "Running pre pass"; | |||
| pre_actions.push_back(std::make_unique<RemovalPass>(RemovalPass())); | |||
| MS_LOG(INFO) << "Running pre pass loops."; | |||
| pre_actions.push_back(std::make_unique<RemovalPass>()); | |||
| pre_actions.push_back(std::make_unique<CacheTransformPass>()); | |||
| // Apply pre action passes | |||
| for (auto &pass : pre_actions) { | |||
| RETURN_IF_NOT_OK(pass->Run(this, &modified)); | |||
| } | |||
| MS_LOG(INFO) << "Pre passes complete."; | |||
| return Status::OK(); | |||
| } | |||
| Status ExecutionTree::PrepareTreePostAction() { | |||
| // The tree is ready to be prepared. | |||
| tree_state_ = kDeTStatePrepare; | |||
| bool modified = false; | |||
| std::vector<std::unique_ptr<Pass>> post_actions; | |||
| // Construct post actions | |||
| MS_LOG(INFO) << "Running post pass loops."; | |||
| post_actions.push_back(std::make_unique<RepeatPass>()); | |||
| // Apply post action passes | |||
| for (auto &pass : post_actions) { | |||
| RETURN_IF_NOT_OK(pass->Run(this, &modified)); | |||
| } | |||
| MS_LOG(INFO) << "Post passes complete."; | |||
| return Status::OK(); | |||
| } | |||
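| A hedged sketch of how the two pass loops are expected to bracket the per-node prepare walk (the actual call site is outside this diff, so the ordering shown is an assumption): | |||
| RETURN_IF_NOT_OK(tree->PrepareTreePreAction());   // RemovalPass, then CacheTransformPass | |||
| // ... per-node walk via ExecutionTree::PrepareNode ... | |||
| RETURN_IF_NOT_OK(tree->PrepareTreePostAction());  // RepeatPass | |||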
| @@ -280,31 +297,5 @@ Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op) | |||
| return Status::OK(); | |||
| } | |||
| // Adds an operator to the eoe operator stack during prepare phase. | |||
| void ExecutionTree::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); } | |||
| // Pops an operator from the eoe operator stack during prepare phase. | |||
| std::shared_ptr<DatasetOp> ExecutionTree::PopFromEOEOpStack() { | |||
| std::shared_ptr<DatasetOp> top_op = nullptr; | |||
| if (!eoe_stack_.empty()) { | |||
| top_op = eoe_stack_.top(); | |||
| eoe_stack_.pop(); | |||
| } | |||
| return top_op; | |||
| } | |||
| // Adds a sampler to the sampler stack during prepare phase. | |||
| void ExecutionTree::AddToSamplerStack(std::shared_ptr<Sampler> sampler) { sampler_stack_.push(sampler); } | |||
| // Pops a sampler from the sampler stack during prepare phase. | |||
| std::shared_ptr<Sampler> ExecutionTree::PopFromSamplerStack() { | |||
| std::shared_ptr<Sampler> top_sampler = nullptr; | |||
| if (!sampler_stack_.empty()) { | |||
| top_sampler = sampler_stack_.top(); | |||
| sampler_stack_.pop(); | |||
| } | |||
| return top_sampler; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -200,24 +200,6 @@ class ExecutionTree { | |||
| // @return Status - The error code return | |||
| Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op); | |||
| /// Adds an operator to the eoe operator stack during prepare phase. | |||
| /// \param dataset_op - The dataset op to add to the eoe stack | |||
| void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op); | |||
| /// Pops an operator from the eoe operator stack during prepare phase. | |||
| /// \return shared_ptr to the popped operator | |||
| std::shared_ptr<DatasetOp> PopFromEOEOpStack(); | |||
| /// Adds a sampler to the sampler stack during prepare phase. | |||
| /// \param sampler - The sampler to add to the sampler stack | |||
| void AddToSamplerStack(std::shared_ptr<Sampler> sampler); | |||
| /// Pops a sampler from the sampler stack during prepare phase. | |||
| /// \return shared_ptr to the popped sampler | |||
| std::shared_ptr<Sampler> PopFromSamplerStack(); | |||
| // Return the pointer to the TaskGroup | |||
| // @return raw pointer to the TaskGroup | |||
| TaskGroup *AllTasks() const { return tg_.get(); } | |||
| @@ -248,8 +230,6 @@ class ExecutionTree { | |||
| TreeState tree_state_; // Tracking the current tree state | |||
| std::unique_ptr<Monitor> perf_monitor_; // Performance Monitor | |||
| std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager | |||
| std::stack<std::shared_ptr<DatasetOp>> eoe_stack_; // A stack used during prepare phase | |||
| std::stack<std::shared_ptr<Sampler>> sampler_stack_; // A stack used during prepare phase | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -2,6 +2,9 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| add_library(engine-opt OBJECT | |||
| pass.cc | |||
| post/repeat_pass.cc | |||
| pre/cache_pass.cc | |||
| pre/cache_transform_pass.cc | |||
| pre/removal_nodes.cc | |||
| pre/removal_pass.cc | |||
| util/printer_pass.cc | |||
| @@ -16,6 +16,9 @@ | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "dataset/engine/datasetops/batch_op.h" | |||
| #include "dataset/engine/datasetops/cache_op.h" | |||
| #include "dataset/engine/datasetops/cache_merge_op.h" | |||
| #include "dataset/engine/datasetops/cache_lookup_op.h" | |||
| #include "dataset/engine/datasetops/dataset_op.h" | |||
| #include "dataset/engine/datasetops/device_queue_op.h" | |||
| #include "dataset/engine/datasetops/map_op.h" | |||
| @@ -24,8 +27,15 @@ | |||
| #include "dataset/engine/datasetops/repeat_op.h" | |||
| #include "dataset/engine/datasetops/skip_op.h" | |||
| #include "dataset/engine/datasetops/shuffle_op.h" | |||
| #include "dataset/engine/datasetops/source/celeba_op.h" | |||
| #include "dataset/engine/datasetops/source/cifar_op.h" | |||
| #include "dataset/engine/datasetops/source/coco_op.h" | |||
| #include "dataset/engine/datasetops/source/manifest_op.h" | |||
| #include "dataset/engine/datasetops/source/mindrecord_op.h" | |||
| #include "dataset/engine/datasetops/source/mnist_op.h" | |||
| #include "dataset/engine/datasetops/source/random_data_op.h" | |||
| #include "dataset/engine/datasetops/source/tf_reader_op.h" | |||
| #include "dataset/engine/datasetops/source/voc_op.h" | |||
| #ifdef ENABLE_PYTHON | |||
| #include "dataset/engine/datasetops/filter_op.h" | |||
| #include "dataset/engine/datasetops/source/generator_op.h" | |||
| @@ -145,6 +155,11 @@ Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||
| } | |||
| #endif | |||
| Status NodePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| @@ -164,5 +179,70 @@ Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -47,6 +47,10 @@ class FilterOp; | |||
| class GeneratorOp; | |||
| #endif | |||
| class RandomDataOp; | |||
| class RepeatOp; | |||
| class TakeOp; | |||
| class ZipOp; | |||
| @@ -55,6 +59,24 @@ class DeviceQueueOp; | |||
| class ImageFolderOp; | |||
| class CacheOp; | |||
| class MnistOp; | |||
| class ManifestOp; | |||
| class CifarOp; | |||
| class VOCOp; | |||
| class CocoOp; | |||
| class CelebAOp; | |||
| class CacheMergeOp; | |||
| class CacheLookupOp; | |||
| // The base class Pass is the basic unit of tree transformation. | |||
| // The actual implementation of the passes will be derived from here. | |||
| class Pass : public std::enable_shared_from_this<Pass> { | |||
| @@ -138,14 +160,42 @@ class NodePass : public Pass { | |||
| virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified); | |||
| #endif | |||
| virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); | |||
| private: | |||
| // Helper function to perform DFS visit | |||
| Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified); | |||
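| To illustrate the visitor contract, a minimal hypothetical pass (not part of this patch): override only the node types you care about; anything unhandled falls through to the DatasetOp overload. | |||
| class CountRepeatsPass : public NodePass { | |||
|  public: | |||
|   Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override { | |||
|     ++num_repeats_;    // observe the node; leave the tree untouched | |||
|     *modified = false; | |||
|     return Status::OK(); | |||
|   } | |||
|  private: | |||
|   int32_t num_repeats_ = 0; | |||
| }; | |||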
| @@ -0,0 +1,161 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "dataset/engine/opt/post/repeat_pass.h" | |||
| #include "dataset/engine/datasetops/repeat_op.h" | |||
| #include "dataset/engine/datasetops/cache_op.h" | |||
| #include "dataset/engine/datasetops/cache_lookup_op.h" | |||
| #include "dataset/engine/datasetops/cache_merge_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| RepeatPass::RepeatPass() : is_repeated_(false), nested_repeats_(0), is_merge_(false), cache_lookup_(nullptr) {} | |||
| // Identifies the subtree below this node as being in a repeated path of the tree. | |||
| Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| // If we are already repeated, then this is a nested repeat. | |||
| if (is_repeated_) { | |||
| nested_repeats_++; | |||
| } | |||
| is_repeated_ = true; | |||
| return Status::OK(); | |||
| } | |||
| // Identifies the subtree below this node as being in a cache merge path | |||
| Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { | |||
| // Turn on the flag that we're under a merge op | |||
| is_merge_ = true; | |||
| return Status::OK(); | |||
| } | |||
| // Hooks up any identified eoe nodes under this repeat. | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking | |||
| std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack(); | |||
| while (leaf_op != nullptr) { | |||
| node->AddToEoeList(leaf_op); | |||
| leaf_op = PopFromEOEOpStack(); | |||
| } | |||
| // If we are a repeat op in the descendant tree of a merge op, then we take the saved lookup op | |||
| // and add it to the list of eoe/leaf ops for the repeat, removing it from the save area. | |||
| if (is_merge_ && cache_lookup_) { | |||
| cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated); | |||
| node->AddToEoeList(std::move(cache_lookup_)); | |||
| } | |||
| // If we are a nested repeat, then we add ourself to the repeat stack for the next one above us. | |||
| // A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree. | |||
| if (nested_repeats_ > 0) { | |||
| node->set_control_flag(DatasetOp::kDeOpRepeated); | |||
| AddToEOEOpStack(node); | |||
| nested_repeats_--; | |||
| } | |||
| // If we are not nested, or we were the top-most repeat, now we clear the flag | |||
| if (nested_repeats_ == 0) { | |||
| is_repeated_ = false; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
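| A worked trace for a hypothetical pipeline Repeat(outer) -> Repeat(inner) -> Leaf, visited pre-order then post-order: | |||
| //   outer pre:  is_repeated_ becomes true | |||
| //   inner pre:  already repeated, so nested_repeats_ becomes 1 | |||
| //   Leaf:       flagged kDeOpRepeated and pushed onto the eoe stack | |||
| //   inner post: pops Leaf into its eoe list; nested, so pushes itself and nested_repeats_ drops to 0 | |||
| //   outer post: pops inner into its eoe list | |||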
| // CacheOp removes previous leaf ops and replaces them with itself | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| if (is_repeated_) { | |||
| node->set_control_flag(DatasetOp::kDeOpRepeated); | |||
| // if we are a cache within a repeat path of the tree, then there will be | |||
| // eoe-generating ops in the eoe op stack in the tree. They are flagged as such so that the | |||
| // repeat or epoch ctrl operators can work with them for repeat activity during runtime. | |||
| // However, since a cache is present: | |||
| // - unflag those ops as being repeated ops | |||
| // - remove them from the eoe op stack so that repeat op above in the tree won't know about them | |||
| // - add ourself (the cache op), as an eoe op | |||
| // We do this so that those old leafs become 1-time use (up to eoe), never repeated. Instead | |||
| // the repeating behaviours shall be invoked against the cache op. | |||
| std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack(); | |||
| while (leaf_op != nullptr) { | |||
| leaf_op->ClearControlFlag(DatasetOp::kDeOpLastRepeat); | |||
| leaf_op->ClearControlFlag(DatasetOp::kDeOpRepeated); | |||
| leaf_op = PopFromEOEOpStack(); | |||
| } | |||
| AddToEOEOpStack(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // All operators have a flag that might be set related to the repeat, and any leaf nodes need to be set up | |||
| // for use with a controlling repeat above them. | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { | |||
| // If we are in a repeat path, then set our repeated flag | |||
| if (is_repeated_) { | |||
| node->set_control_flag(DatasetOp::kDeOpRepeated); | |||
| // if we are a leaf node then save ourself in a stack for the repeat operator above us | |||
| if (node->IsLeaf()) { | |||
| AddToEOEOpStack(node); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Turns off the tracking for operations under merge op | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { | |||
| // Setting the flag is needed since we didn't call the base class DatasetOp version | |||
| if (is_repeated_) node->set_control_flag(DatasetOp::kDeOpRepeated); | |||
| is_merge_ = false; | |||
| cache_lookup_.reset(); // If a repeat op did not consume this then it's no longer needed | |||
| return Status::OK(); | |||
| } | |||
| // Saves the lookup op in case it needs to be referenced by a repeat | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) { | |||
| if (!node->IsLeaf()) { | |||
| // By definition, the CacheLookup must be a leaf op. Make that clear here. | |||
| RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!"); | |||
| } | |||
| // If we are in a repeat path already, then there must be a repeat above the merge op | |||
| // In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here. | |||
| if (is_repeated_) { | |||
| node->set_control_flag(DatasetOp::kDeOpRepeated); | |||
| AddToEOEOpStack(node); | |||
| } else { | |||
| // save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we | |||
| // may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself | |||
| // into the pass so that the decision can be made during the processing of the cache miss leg of the merge. | |||
| cache_lookup_ = std::static_pointer_cast<DatasetOp>(node); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Adds an operator to the eoe operator stack save area | |||
| void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); } | |||
| // Pops an operator from the eoe operator stack save area | |||
| std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() { | |||
| std::shared_ptr<DatasetOp> top_op = nullptr; | |||
| if (!eoe_stack_.empty()) { | |||
| top_op = eoe_stack_.top(); | |||
| eoe_stack_.pop(); | |||
| } | |||
| return top_op; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,98 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ | |||
| #define DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ | |||
| #include <memory> | |||
| #include <stack> | |||
| #include <utility> | |||
| #include "dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \class RepeatPass repeat_pass.h | |||
| /// \brief This is a NodePass whose job is to perform setup actions for RepeatOps. A RepeatOp needs to have references | |||
| /// to the eoe-producing (typically leaf) nodes underneath it. | |||
| class RepeatPass : public NodePass { | |||
| public: | |||
| /// \brief Constructor | |||
| RepeatPass(); | |||
| /// \brief Identifies the subtree below this node as being in a repeated path of the tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | |||
| /// \brief Identifies the subtree below this node as being in a cache merge path | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override; | |||
| /// \brief Hooks up any identified eoe nodes under this repeat. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | |||
| /// \brief CacheOp removes previous leaf ops and replaces them with itself | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Turns off the tracking for operations under merge op | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override; | |||
| /// \brief Saves the lookup op in case it needs to be referenced by a repeat | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) override; | |||
| /// \brief All operators have a flag that might be set related to the repeat, and any leaf nodes need to be set up | |||
| /// for use with a controlling repeat above them. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override; | |||
| private: | |||
| /// \brief Adds an operator to the eoe operator stack save area | |||
| /// \param dataset_op - The dataset op to add to the eoe stack | |||
| void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op); | |||
| /// \brief Pops an operator from the eoe operator stack save area | |||
| /// \return shared_ptr to the popped operator | |||
| std::shared_ptr<DatasetOp> PopFromEOEOpStack(); | |||
| bool is_repeated_; // T/F if we are processing under a repeat | |||
| bool is_merge_; // T/F if we are processing under a cache merge op | |||
| int32_t nested_repeats_; // A counter for nested repeats | |||
| std::stack<std::shared_ptr<DatasetOp>> eoe_stack_; // A save area for leaf/eoe ops | |||
| std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ | |||
| @@ -0,0 +1,181 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "dataset/engine/opt/pre/cache_pass.h" | |||
| #include "dataset/engine/opt/pre/cache_transform_pass.h" | |||
| #include "dataset/engine/datasetops/cache_op.h" | |||
| #include "dataset/engine/datasetops/source/celeba_op.h" | |||
| #include "dataset/engine/datasetops/source/generator_op.h" | |||
| #include "dataset/engine/datasetops/source/manifest_op.h" | |||
| #include "dataset/engine/datasetops/source/mnist_op.h" | |||
| #include "dataset/engine/datasetops/source/voc_op.h" | |||
| #include "dataset/engine/datasetops/source/cifar_op.h" | |||
| #include "dataset/engine/datasetops/source/coco_op.h" | |||
| #include "dataset/engine/datasetops/source/image_folder_op.h" | |||
| #include "dataset/engine/datasetops/source/random_data_op.h" | |||
| #include "dataset/engine/datasetops/source/tf_reader_op.h" | |||
| #include "dataset/engine/datasetops/source/mindrecord_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor | |||
| CachePass::CachePass(CacheTransformPass *transform_pass) | |||
| : transform_pass_(transform_pass), is_caching_(false), leaf_op_(nullptr) {} | |||
| // Identifies the subtree below this node as a cached descendant tree. | |||
| Status CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| *modified = false; | |||
| MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; | |||
| if (is_caching_) { | |||
| RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!"); | |||
| } | |||
| is_caching_ = true; | |||
| return Status::OK(); | |||
| } | |||
| // Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache | |||
| // transformation | |||
| Status CachePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| *modified = false; | |||
| is_caching_ = false; // We are no longer in a cache subtree. Clear the flag. | |||
| if (leaf_op_) { | |||
| MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache."; | |||
| // Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op, | |||
| // using base class pointers. | |||
| transform_pass_->AddMappableCacheOperators(std::move(leaf_op_), node); | |||
| } else { | |||
| // If there was no leaf_op set, then this is a non-mappable scenario. | |||
| if (sampler_) { | |||
| // Grab the sampler that was saved from the leaf and plug it into the cache op | |||
| node->SetSampler(std::move(sampler_)); | |||
| MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf."; | |||
| } else { | |||
| // We're a cache op but no sampler was saved from leaf, so create a default sampler | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| sampler_ = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| node->SetSampler(std::move(sampler_)); | |||
| MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op."; | |||
| } | |||
| // Get the computed checksum from all ops in our cache path below us and ask the cache op to create its cache | |||
| uint32_t cache_crc = DatasetOp::GenerateCRC(node); | |||
| RETURN_IF_NOT_OK(node->CreateCache(cache_crc)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Common code for mappable leaf setup. | |||
| Status CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | |||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | |||
| if (is_caching_ && leaf_op_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); | |||
| } | |||
| // If we are a leaf in the caching path, then save this leaf. | |||
| if (is_caching_) { | |||
| MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected"; | |||
| leaf_op_ = std::move(leaf_op); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Common code for non-mappable leaf setup. | |||
| Status CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | |||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | |||
| if (is_caching_ && leaf_op_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); | |||
| } | |||
| // Sampler for a non-mappable dataset only works if there is a downstream cache. Remove it from the leaf | |||
| // and save it for use by the cache op in the ascendant tree. | |||
| if (is_caching_) { | |||
| RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_)); | |||
| MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected"; | |||
| } else { | |||
| // If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can | |||
| // remove it here. The leaf itself will provide its own methods of fetching the data (not sampler-based) | |||
| std::shared_ptr<Sampler> sampler_from_leaf; | |||
| RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
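| To summarize the two code paths above as assumed pipeline shapes (illustrative only): | |||
| //   mappable leaf (e.g. ImageFolderOp):  leaf_op_ is saved and later handed to the transform pass for rewiring | |||
| //   non-mappable leaf (e.g. TFReaderOp): the leaf's sampler is moved up onto the CacheOp via FetchRemoveSampler | |||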
| // Perform leaf node cache transform identifications | |||
| Status CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) { | |||
| if (is_caching_) { | |||
| // If we are a TF Reader in a caching tree, then change our config so that it becomes a basic | |||
| // TF reader that parses all files. Selection of data will come from the sampler on the cache instead. | |||
| node->MakeSimpleProducer(); | |||
| } | |||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CachePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) { | |||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
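| A minimal sketch of the two scenarios this pass distinguishes, written against the Python API added later in this change set (the session id and a running cache server are assumptions; this pass does not provide them): | |||
| | |||
| import mindspore.dataset as ds | |||
| | |||
| # Hypothetical session id; session management is out of scope for this pass. | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=False) | |||
| | |||
| # Mappable leaf: the pass records the (leaf, cache) pair so that the transform | |||
| # pass can later splice in the CacheLookupOp/CacheMergeOp pair. | |||
| mappable = ds.ImageFolderDatasetV2("/path/to/images", cache=some_cache) | |||
| | |||
| # Non-mappable leaf: the sampler is pulled off the leaf and handed to the cache | |||
| # op; with no sampling arguments given, a default SequentialSampler is created | |||
| # and the cache is keyed by the CRC computed over the subtree. | |||
| non_mappable = ds.TFRecordDataset(["/path/to/data.tfrecord"], cache=some_cache) | |||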
| @@ -0,0 +1,138 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ | |||
| #define DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class CacheTransformPass; | |||
| /// \class CachePass cache_pass.h | |||
| /// \brief This is a NodePass whose job is to identify and set up the nodes that will be involved in a cache | |||
| /// transformation. It works in conjunction with the CacheTransformPass | |||
| class CachePass : public NodePass { | |||
| public: | |||
| /// \brief Constructor | |||
| /// \param[in] transform_pass Raw pointer back to controlling tree pass | |||
| explicit CachePass(CacheTransformPass *transform_pass); | |||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache | |||
| /// transformation | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identification | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identification | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identification | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identification | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identification | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identification | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identification | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identification | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identification | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identification | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identification | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override; | |||
| private: | |||
| /// \brief Common code for mappable leaf setup. | |||
| /// \param[in] leaf_op The leaf node performing setup work. | |||
| /// \return Status The error code return | |||
| Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | |||
| /// \brief Common code for non-mappable leaf setup. | |||
| /// \param[in] leaf_op The leaf node performing setup work. | |||
| /// \return Status The error code return | |||
| Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | |||
| bool is_caching_; | |||
| std::shared_ptr<DatasetOp> leaf_op_; | |||
| std::shared_ptr<Sampler> sampler_; | |||
| CacheTransformPass *transform_pass_; // Back pointer to the owning transform pass | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ | |||
| @@ -0,0 +1,108 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include "dataset/engine/opt/pre/cache_pass.h" | |||
| #include "dataset/engine/opt/pre/cache_transform_pass.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/engine/cache/cache_client.h" | |||
| #include "dataset/engine/datasetops/cache_lookup_op.h" | |||
| #include "dataset/engine/datasetops/cache_merge_op.h" | |||
| #include "dataset/engine/datasetops/cache_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // constructor | |||
| CacheTransformPass::CacheTransformPass() {} | |||
| // Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations | |||
| Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) { | |||
| MS_LOG(INFO) << "Pre pass: Cache transform pass started."; | |||
| // Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will | |||
| // use to execute a transform. | |||
| std::unique_ptr<Pass> cache_pass = std::make_unique<CachePass>(this); | |||
| RETURN_IF_NOT_OK(cache_pass->Run(tree, modified)); | |||
| // Then, execute the transform for each pair | |||
| for (const auto &cache_pair : cache_pairs_) { | |||
| MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform."; | |||
| RETURN_IF_NOT_OK( | |||
| ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client())); | |||
| } | |||
| MS_LOG(INFO) << "Pre pass: Cache transform pass complete."; | |||
| return Status::OK(); | |||
| } | |||
| // Helper function to execute the cache transformation. | |||
| Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op, | |||
| std::shared_ptr<DatasetOp> cache_op, | |||
| std::shared_ptr<CacheClient> cache_client) { | |||
| // Get local pointers to the child/parent of the cache op. It's possible that the parent is null if the cache was | |||
| // the root node. It is also possible that cache_child == leaf_op | |||
| std::shared_ptr<DatasetOp> cache_child = cache_op->child(0); | |||
| DatasetOp *cache_parent = nullptr; | |||
| cache_op->Parent(&cache_parent, 0); // fetch the cache op's parent | |||
| // Extract the sampler from the leaf. We will overwrite this sampler with the lookup op later. | |||
| std::shared_ptr<Sampler> leaf_sampler = leaf_op->sampler(); | |||
| // Construct the merge op with defaults | |||
| std::shared_ptr<CacheMergeOp> merge_op; | |||
| CacheMergeOp::Builder merge_builder; | |||
| RETURN_IF_NOT_OK(merge_builder.SetClient(cache_client).Build(&merge_op)); | |||
| RETURN_IF_NOT_OK(tree->AssociateNode(merge_op)); | |||
| // Construct the cache lookup op with defaults | |||
| std::shared_ptr<CacheLookupOp> cache_lookup_op; | |||
| CacheLookupOp::Builder lookup_builder; | |||
| RETURN_IF_NOT_OK(lookup_builder.SetClient(cache_client).SetSampler(std::move(leaf_sampler)).Build(&cache_lookup_op)); | |||
| RETURN_IF_NOT_OK(tree->AssociateNode(cache_lookup_op)); | |||
| // Overwrite the old sampler in this leaf op to become the lookup op | |||
| leaf_op->SetSampler(cache_lookup_op); | |||
| // If the cache had a parent, then go into that parent to remove the cache from its child list and then | |||
| // replace it with the merge op. | |||
| if (cache_parent != nullptr) { | |||
| RETURN_IF_NOT_OK(cache_parent->RemoveChild(cache_op)); | |||
| RETURN_IF_NOT_OK(cache_parent->AddChild(merge_op)); | |||
| } else { | |||
| // If we didn't have a parent, then the merge op is the root node | |||
| RETURN_IF_NOT_OK(tree->AssignRoot(merge_op)); | |||
| } | |||
| // Set the cache op to no longer be a parent over its child. This will fully disconnect the old cache op. | |||
| // We maintain a local pointer to the old child though. | |||
| RETURN_IF_NOT_OK(cache_op->RemoveChild(cache_child)); | |||
| // Connect the merge op | |||
| RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_lookup_op))); | |||
| RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_child))); | |||
| // At this point, the cache op has already had its children and parents taken away. Calling remove | |||
| // on it at this point will not do any node hookups, and instead set internal fields to invalid. | |||
| RETURN_IF_NOT_OK(cache_op->Remove()); | |||
| return Status::OK(); | |||
| } | |||
| // Assigns the leaf and cache operators that are involved in a cache transformation | |||
| void CacheTransformPass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, | |||
| std::shared_ptr<CacheOp> cache_op) { | |||
| cache_pairs_.push_back(std::make_pair(leaf_op, cache_op)); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,79 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ | |||
| #define DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class DatasetOp; | |||
| class CacheClient; | |||
| /// \class CacheTransformPass cache_transform_pass.h | |||
| /// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching | |||
| /// operations | |||
| class CacheTransformPass : public TreePass { | |||
| public: | |||
| /// \brief Constructor | |||
| CacheTransformPass(); | |||
| /// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] modified Indicator if the tree was modified. | |||
| /// \return Status The error code return | |||
| Status RunOnTree(ExecutionTree *tree, bool *modified) override; | |||
| /// \brief Assigns the leaf and cache operators that are involved in a cache transformation | |||
| /// \param[in] leaf_op The leaf operator involved in the cache transform | |||
| /// \param[in] cache_op The cache operator involved in the cache transform | |||
| void AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, std::shared_ptr<CacheOp> cache_op); | |||
| private: | |||
| /// \brief Helper function to execute the cache transformation. | |||
| /// | |||
| /// Input: | |||
| ///   Sampler | |||
| ///     | | |||
| ///   LeafOp --> OtherOps --> CacheOp | |||
| /// | |||
| /// Transformed: | |||
| ///   Sampler --> CacheLookupOp ----------------> | |||
| ///     |                               | | |||
| ///     |                           MergeOp | |||
| ///     |                               | | |||
| ///   LeafOp --> OtherOps ------------> | |||
| /// | |||
| /// \param[in] tree The tree in which the transform runs | |||
| /// \param[in] leaf_op The leaf node in the transform | |||
| /// \param[in] cache_op The cache op in the transform (will get removed) | |||
| /// \param[in] cache_client The cache client | |||
| /// \return Status The error code return | |||
| Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op, | |||
| std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client); | |||
| // The two operators that work together to establish the cache transform | |||
| std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ | |||
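| To make the rewiring in the diagram above concrete, here is a toy, dependency-free Python model of the same pointer surgery (the Node class and helper names are illustrative, not the C++ API): | |||
| | |||
| class Node: | |||
|     """Toy stand-in for DatasetOp: a name plus child pointers.""" | |||
|     def __init__(self, name, children=None): | |||
|         self.name = name | |||
|         self.children = children if children is not None else [] | |||
| | |||
| def find_parent(node, target): | |||
|     """Depth-first search for the parent of target; None if target is the root.""" | |||
|     for child in node.children: | |||
|         if child is target: | |||
|             return node | |||
|         found = find_parent(child, target) | |||
|         if found is not None: | |||
|             return found | |||
|     return None | |||
| | |||
| def execute_cache_transform(root, leaf, cache, lookup, merge): | |||
|     """Toy mirror of ExecuteCacheTransform; returns the (possibly new) root.""" | |||
|     cache_child = cache.children[0] | |||
|     parent = find_parent(root, cache)       # None means the cache op was the root | |||
|     leaf.sampler = lookup                   # the lookup op replaces the leaf's sampler | |||
|     if parent is not None:                  # splice the merge op in place of the cache op | |||
|         parent.children[parent.children.index(cache)] = merge | |||
|     else: | |||
|         root = merge                        # no parent: the merge op becomes the root | |||
|     cache.children.clear()                  # fully disconnect the old cache op | |||
|     merge.children = [lookup, cache_child]  # merge reads the lookup and the old subtree | |||
|     return root | |||
| | |||
| # leaf --> other --> cache, as in the "Input" diagram above | |||
| leaf = Node("leaf") | |||
| cache = Node("cache", [Node("other", [leaf])]) | |||
| new_root = execute_cache_transform(cache, leaf, cache, Node("lookup"), Node("merge")) | |||
| print(new_root.name)  # merge | |||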
| @@ -24,12 +24,28 @@ namespace dataset { | |||
| RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {} | |||
| // Identifies the subtree below this node as a cached descendant tree. | |||
| Status RemovalNodes::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| *modified = false; | |||
| MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree."; | |||
| is_caching_ = true; | |||
| return Status::OK(); | |||
| } | |||
| // Resets the tracking of the cache within the tree | |||
| Status RemovalNodes::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| *modified = false; | |||
| MS_LOG(INFO) << "Removal pass: cache descendant tree complete."; | |||
| is_caching_ = false; | |||
| return Status::OK(); | |||
| } | |||
| // Perform ShuffleOp removal check. | |||
| Status RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) { | |||
| *modified = false; | |||
| // If we are in a cache descendant tree, then this shuffle op needs to be removed | |||
| if (is_caching_) { | |||
| MS_LOG(DEBUG) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; | |||
| MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; | |||
| if (removal_pass_) { | |||
| removal_pass_->AddToRemovalList(std::static_pointer_cast<DatasetOp>(node)); | |||
| } else { | |||
| @@ -34,6 +34,18 @@ class RemovalNodes : public NodePass { | |||
| /// \param[in] removal_pass Raw pointer back to controlling tree pass | |||
| explicit RemovalNodes(RemovalPass *removal_pass); | |||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Resets the tracking of the cache within the tree | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Perform ShuffleOp removal check | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| @@ -28,6 +28,7 @@ RemovalPass::RemovalPass() {} | |||
| // Runs a removal_nodes pass first to find out which nodes to remove, then removes them. | |||
| Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) { | |||
| MS_LOG(INFO) << "Pre pass: removal pass started."; | |||
| // Create the removal node pass which can identify which nodes need to be removed. | |||
| std::unique_ptr<Pass> removal_nodes = std::make_unique<RemovalNodes>(this); | |||
| RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified)); | |||
| @@ -36,6 +37,7 @@ Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) { | |||
| for (auto node : removal_nodes_) { | |||
| node->Remove(); | |||
| } | |||
| MS_LOG(INFO) << "Pre pass: removal pass complete."; | |||
| return Status::OK(); | |||
| } | |||
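| The identify-then-remove split matters: deleting nodes while the visitor is still walking the tree would invalidate the traversal. A rough Python analogue with toy node objects (not the C++ pass API): | |||
| | |||
| class ToyNode: | |||
|     def __init__(self, kind, children=None): | |||
|         self.kind = kind | |||
|         self.parent = None | |||
|         self.children = children if children is not None else [] | |||
|         for c in self.children: | |||
|             c.parent = self | |||
| | |||
|     def remove(self): | |||
|         """Splice this node out: the parent adopts this node's children.""" | |||
|         idx = self.parent.children.index(self) | |||
|         self.parent.children[idx:idx + 1] = self.children | |||
|         for c in self.children: | |||
|             c.parent = self.parent | |||
| | |||
| class ToyRemovalPass: | |||
|     def __init__(self): | |||
|         self.removal_nodes = [] | |||
| | |||
|     def run_on_tree(self, root): | |||
|         self._collect(root, caching=False)   # phase 1: only record what must go | |||
|         for node in self.removal_nodes:      # phase 2: mutate after the walk completes | |||
|             node.remove() | |||
| | |||
|     def _collect(self, node, caching): | |||
|         caching = caching or node.kind == "cache"   # everything below a cache is cached | |||
|         if node.kind == "shuffle" and caching: | |||
|             # The cache's sampler supplies the randomness, so the shuffle is redundant. | |||
|             self.removal_nodes.append(node) | |||
|         for child in node.children: | |||
|             self._collect(child, caching) | |||
| | |||
| # cache --> shuffle --> leaf: the shuffle under the cache gets spliced out. | |||
| root = ToyNode("cache", [ToyNode("shuffle", [ToyNode("leaf")])]) | |||
| ToyRemovalPass().run_on_tree(root) | |||
| print([c.kind for c in root.children])  # ['leaf'] | |||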
| @@ -87,8 +87,9 @@ class Allocator { | |||
| std::shared_ptr<MemoryPool> pool_; | |||
| }; | |||
| /// \brief It is a wrapper of unique_ptr with a custom allocator and acts like std::lock_guard such that the memory will | |||
| /// be released when the object goes out of scope \tparam T The type of object to be allocated \tparam C Allocator. | |||
| /// Default to std::allocator | |||
| /// be released when the object goes out of scope | |||
| /// \tparam T The type of object to be allocated | |||
| /// \tparam C Allocator. Default to std::allocator | |||
| template <typename T, typename C = std::allocator<T>> | |||
| class MemGuard { | |||
| public: | |||
| @@ -168,7 +169,7 @@ class MemGuard { | |||
| private: | |||
| allocator alloc_; | |||
| std::unique_ptr<T[], std::function<void(T *)>> ptr_; | |||
| std::unique_ptr<T[]> ptr_; | |||
| size_t n_; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -98,11 +98,6 @@ Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_t | |||
| } catch (std::bad_alloc &e) { | |||
| if (sm_ != nullptr) { | |||
| RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); | |||
| // We have an assumption 0 is not a valid key from the design of AutoIndexObj. | |||
| // Make sure it is not 0. | |||
| if (bl.storage_key == 0) { | |||
| RETURN_STATUS_UNEXPECTED("Key 0 is returned which is unexpected"); | |||
| } | |||
| } else { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||
| } | |||
| @@ -22,11 +22,11 @@ | |||
| #include <stdlib.h> | |||
| #endif | |||
| #include <unistd.h> | |||
| #include "dataset/engine/cache/cache_server.h" | |||
| #include "dataset/util/circular_pool.h" | |||
| #include "dataset/util/random.h" | |||
| #include "dataset/util/task_manager.h" | |||
| #define SLOT_TASK_MGR 0 | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| std::unique_ptr<Services> Services::instance_ = nullptr; | |||
| @@ -61,15 +61,25 @@ std::string Services::GetUniqueID() { | |||
| TaskManager &Services::getTaskMgrInstance() { | |||
| Services &sm = GetInstance(); | |||
| return *(static_cast<TaskManager *>(sm.sa_[SLOT_TASK_MGR])); | |||
| return *(static_cast<TaskManager *>(sm.sa_[kSlotTaskMgr_])); | |||
| } | |||
| CacheServer &Services::getCacheServer() { | |||
| Services &sm = GetInstance(); | |||
| return *(static_cast<CacheServer *>(sm.sa_[kSlotCacheMgr_])); | |||
| } | |||
| Status Services::CreateAllInstances() { | |||
| // In order: TaskMgr, CacheServer | |||
| Status rc; | |||
| sa_[SLOT_TASK_MGR] = new (&rc, pool_) TaskManager(); | |||
| sa_[kSlotTaskMgr_] = new (&rc, pool_) TaskManager(); | |||
| RETURN_IF_NOT_OK(rc); | |||
| rc = sa_[SLOT_TASK_MGR]->ServiceStart(); | |||
| rc = sa_[kSlotTaskMgr_]->ServiceStart(); | |||
| RETURN_IF_NOT_OK(rc); | |||
| // TODO(jesse) : Get the parameters from config file. Right now spill to /tmp and spawn 3 workers | |||
| sa_[kSlotCacheMgr_] = new (&rc, pool_) CacheServer("/tmp", 3); | |||
| RETURN_IF_NOT_OK(rc); | |||
| rc = sa_[kSlotCacheMgr_]->ServiceStart(); | |||
| return rc; | |||
| } | |||
| @@ -83,8 +93,14 @@ Services::Services() : pool_(nullptr), sa_{nullptr} { | |||
| Services::~Services() noexcept { | |||
| try { | |||
| // In reverse order | |||
| TaskManager *tm = static_cast<TaskManager *>(sa_[SLOT_TASK_MGR]); | |||
| if (tm) { | |||
| CacheServer *cs = static_cast<CacheServer *>(sa_[kSlotCacheMgr_]); | |||
| if (cs != nullptr) { | |||
| (void)cs->ServiceStop(); | |||
| cs->~CacheServer(); | |||
| pool_->Deallocate(cs); | |||
| } | |||
| TaskManager *tm = static_cast<TaskManager *>(sa_[kSlotTaskMgr_]); | |||
| if (tm != nullptr) { | |||
| (void)tm->ServiceStop(); | |||
| tm->~TaskManager(); | |||
| pool_->Deallocate(tm); | |||
| @@ -27,7 +27,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class TaskManager; | |||
| class CacheServer; | |||
| class Services { | |||
| public: | |||
| static Status CreateInstance() { | |||
| @@ -61,6 +61,8 @@ class Services { | |||
| static TaskManager &getTaskMgrInstance(); | |||
| static CacheServer &getCacheServer(); | |||
| std::shared_ptr<MemoryPool> GetServiceMemPool() { return pool_; } | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| @@ -87,7 +89,9 @@ class Services { | |||
| // We use pointers here instead of unique_ptr because we | |||
| // want to have ultimate control on the order of | |||
| // construction and destruction. | |||
| static constexpr int kNumServices_ = 1; | |||
| static constexpr int kSlotTaskMgr_ = 0; | |||
| static constexpr int kSlotCacheMgr_ = 1; | |||
| static constexpr int kNumServices_ = 2; | |||
| Service *sa_[kNumServices_]; | |||
| Services(); | |||
| @@ -24,6 +24,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset | |||
| TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset | |||
| from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ | |||
| WeightedRandomSampler, Sampler | |||
| from .engine.cache_client import DatasetCache | |||
| from .engine.serializer_deserializer import serialize, deserialize, show | |||
| from .engine.graphdata import GraphData | |||
| @@ -0,0 +1,49 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Cache client | |||
| """ | |||
| import copy | |||
| from mindspore._c_dataengine import CacheClient | |||
| class DatasetCache: | |||
| """ | |||
| A client to interface with the tensor caching service. | |||
| """ | |||
| def __init__(self, session_id=None, size=None, spilling=False): | |||
| if session_id is None: | |||
| raise RuntimeError("Session generation is not implemented yet. A session id is required.") | |||
| # Validate against self.size so that size=None (defaulted to 0) does not raise a TypeError on comparison. | |||
| self.size = size if size is not None else 0 | |||
| if self.size < 0: | |||
| raise ValueError("cache size should be 0 or a positive integer value but got: size={}".format(size)) | |||
| if not isinstance(spilling, bool): | |||
| raise ValueError( | |||
| "spilling argument for cache should be a boolean value but got: spilling={}".format(spilling)) | |||
| self.session_id = session_id | |||
| self.spilling = spilling | |||
| self.cache_client = CacheClient(session_id, size, spilling) | |||
| def __deepcopy__(self, memodict): | |||
| if id(self) in memodict: | |||
| return memodict[id(self)] | |||
| cls = self.__class__ | |||
| new_cache = cls.__new__(cls) | |||
| memodict[id(self)] = new_cache | |||
| new_cache.session_id = copy.deepcopy(self.session_id, memodict) | |||
| new_cache.spilling = copy.deepcopy(self.spilling, memodict) | |||
| new_cache.size = copy.deepcopy(self.size, memodict) | |||
| new_cache.cache_client = self.cache_client | |||
| return new_cache | |||
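| A minimal usage sketch (assumptions: a cache server is already running, session id 1 was provisioned out of band since session generation is not implemented yet, and the image path is a placeholder): | |||
| | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| | |||
| # size=0 means no explicit memory cap; spilling=False keeps everything in memory. | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=False) | |||
| | |||
| data = ds.ImageFolderDatasetV2("/path/to/images", num_parallel_workers=4) | |||
| # A cached map: on later epochs the decoded tensors are served from the cache. | |||
| data = data.map(input_columns=["image"], operations=c_vision.Decode(), cache=some_cache) | |||
| | |||
| Note that __deepcopy__ above deliberately shares cache_client rather than copying it, so the dataset copies made during tree construction all point at the same underlying cache. | |||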
| @@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che | |||
| check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ | |||
| check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ | |||
| check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ | |||
| check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32 | |||
| check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32 | |||
| from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist | |||
| try: | |||
| @@ -386,7 +386,7 @@ class Dataset: | |||
| @check_map | |||
| def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None, | |||
| num_parallel_workers=None, python_multiprocessing=False): | |||
| num_parallel_workers=None, python_multiprocessing=False, cache=None): | |||
| """ | |||
| Apply each operation in operations to this dataset. | |||
| @@ -427,6 +427,7 @@ class Dataset: | |||
| parallel (default=None, the value from the config will be used). | |||
| python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This | |||
| option could be beneficial if the python operation is computational heavy (default=False). | |||
| cache (DatasetCache, optional): Tensor cache to use (default=None, which means no cache is used). | |||
| Returns: | |||
| MapDataset, dataset after mapping operation. | |||
| @@ -541,7 +542,7 @@ class Dataset: | |||
| >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order) | |||
| """ | |||
| return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers, | |||
| python_multiprocessing) | |||
| python_multiprocessing, cache) | |||
| @check_filter | |||
| def filter(self, predicate, input_columns=None, num_parallel_workers=1): | |||
| @@ -1868,13 +1869,14 @@ class MapDataset(DatasetOp): | |||
| in parallel (default=None). | |||
| python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This | |||
| option could be beneficial if the python operation is computational heavy (default=False). | |||
| cache (DatasetCache, optional): Tensor cache to use (default=None, which means no cache is used). | |||
| Raises: | |||
| ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified. | |||
| """ | |||
| def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None, | |||
| num_parallel_workers=None, python_multiprocessing=False): | |||
| num_parallel_workers=None, python_multiprocessing=False, cache=None): | |||
| super().__init__(num_parallel_workers) | |||
| self.children.append(input_dataset) | |||
| if input_columns is not None and not isinstance(input_columns, list): | |||
| @@ -1886,6 +1888,7 @@ class MapDataset(DatasetOp): | |||
| if output_columns is not None and not isinstance(output_columns, list): | |||
| output_columns = [output_columns] | |||
| self.output_columns = output_columns | |||
| self.cache = cache | |||
| self.columns_order = columns_order | |||
| if self.input_columns and self.output_columns \ | |||
| @@ -1904,6 +1907,7 @@ class MapDataset(DatasetOp): | |||
| args["operations"] = self.operations | |||
| args["output_columns"] = self.output_columns | |||
| args["columns_order"] = self.columns_order | |||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||
| return args | |||
| def get_dataset_size(self): | |||
| @@ -1929,6 +1933,7 @@ class MapDataset(DatasetOp): | |||
| new_op.parent = copy.deepcopy(self.parent, memodict) | |||
| new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict) | |||
| new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) | |||
| new_op.cache = copy.deepcopy(self.cache, memodict) | |||
| new_op.operations = self.operations | |||
| return new_op | |||
| @@ -2346,7 +2351,7 @@ class RangeDataset(MappableDataset): | |||
| return False | |||
| def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): | |||
| def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, non_mappable=False): | |||
| """ | |||
| Create sampler based on user input. | |||
| @@ -2356,7 +2361,11 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): | |||
| shuffle (bool): Shuffle. | |||
| num_shards (int): Number of shard for sharding. | |||
| shard_id (int): Shard ID. | |||
| non_mappable (bool, optional): Indicates whether the caller is a non-mappable dataset, requiring special handling (default=False). | |||
| """ | |||
| if non_mappable is True and all(arg is None for arg in [num_samples, shuffle, num_shards, shard_id, input_sampler]): | |||
| return None | |||
| if input_sampler is not None: | |||
| # If the user provided a sampler, then it doesn't matter what the other args are because | |||
| # we are being asked specifically to use the given sampler. | |||
| @@ -2369,7 +2378,7 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): | |||
| if (isinstance(input_sampler, (samplers.SequentialSampler, samplers.DistributedSampler, | |||
| samplers.RandomSampler, samplers.SubsetRandomSampler, | |||
| samplers.WeightedRandomSampler, samplers.Sampler)) and | |||
| (num_shards is not None or shard_id is not None or shuffle is not None or num_samples is not None)): | |||
| (any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))): | |||
| raise ValueError( | |||
| 'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},' | |||
| ' shard_id: {}, shuffle: {})'.format(num_samples, num_shards, shard_id, shuffle)) | |||
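| Two consequences of this selection logic, sketched against the private helper (the calls below are illustrative, not a public API): | |||
| | |||
| import mindspore.dataset as ds | |||
| from mindspore.dataset.engine.datasets import _select_sampler | |||
| | |||
| # Non-mappable caller with no sampling arguments at all: no sampler is built, so a | |||
| # leaf with no cache above it keeps its native, non-sampler-based way of reading. | |||
| assert _select_sampler(None, None, None, None, None, non_mappable=True) is None | |||
| | |||
| # A user-provided sampler conflicts with any other sampling argument. | |||
| try: | |||
|     _select_sampler(10, ds.SequentialSampler(), None, None, None) | |||
| except ValueError as err: | |||
|     print(err)  # Conflicting arguments during sampler assignments ... | |||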
| @@ -2458,6 +2467,7 @@ class ImageFolderDatasetV2(MappableDataset): | |||
| into (default=None). | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument should be specified only when num_shards is also specified. | |||
| cache (DatasetCache, optional): Tensor cache to use (default=None, which means no cache is used). | |||
| Raises: | |||
| RuntimeError: If sampler and shuffle are specified at the same time. | |||
| @@ -2482,7 +2492,7 @@ class ImageFolderDatasetV2(MappableDataset): | |||
| @check_imagefolderdatasetv2 | |||
| def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, | |||
| shuffle=None, sampler=None, extensions=None, class_indexing=None, | |||
| decode=False, num_shards=None, shard_id=None): | |||
| decode=False, num_shards=None, shard_id=None, cache=None): | |||
| super().__init__(num_parallel_workers) | |||
| self.dataset_dir = dataset_dir | |||
| @@ -2494,6 +2504,7 @@ class ImageFolderDatasetV2(MappableDataset): | |||
| self.decode = decode | |||
| self.num_shards = num_shards | |||
| self.shard_id = shard_id | |||
| self.cache = cache | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| @@ -2506,6 +2517,7 @@ class ImageFolderDatasetV2(MappableDataset): | |||
| args["decode"] = self.decode | |||
| args["num_shards"] = self.num_shards | |||
| args["shard_id"] = self.shard_id | |||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||
| return args | |||
| def get_dataset_size(self): | |||
| @@ -3251,6 +3263,7 @@ class TFRecordDataset(SourceDataset): | |||
| argument should be specified only when num_shards is also specified. | |||
| shard_equal_rows (bool): Get equal rows for all shards(default=False). If shard_equal_rows is false, number | |||
| of rows of each shard may be not equal. | |||
| cache (DatasetCache, optional): Tensor cache to use (default=None, which means no cache is used). | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> import mindspore.common.dtype as mstype | |||
| @@ -3268,7 +3281,7 @@ class TFRecordDataset(SourceDataset): | |||
| @check_tfrecorddataset | |||
| def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, | |||
| shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False): | |||
| shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None): | |||
| super().__init__(num_parallel_workers) | |||
| self.dataset_files = self._find_files(dataset_files) | |||
| self.dataset_files.sort() | |||
| @@ -3280,6 +3293,7 @@ class TFRecordDataset(SourceDataset): | |||
| self.schema = schema | |||
| self.columns_list = columns_list | |||
| self.num_samples = num_samples | |||
| self.cache = cache | |||
| if schema_obj is not None and num_samples is None: | |||
| self.num_samples = schema_obj.num_rows | |||
| @@ -3295,6 +3309,14 @@ class TFRecordDataset(SourceDataset): | |||
| else: | |||
| self.shuffle_level = shuffle | |||
| self.shuffle_files = True | |||
| # The TF record dataset does not directly support a sampler. It has its own sampling arguments | |||
| # (shuffle, num_samples, num_shards, shard_id), and it DOES support sampling if a cache exists | |||
| # somewhere above it in the pipeline. If there is no cache above it, then this sampler is not used. | |||
| sampler_shuffle = self.shuffle_files | |||
| sampler = None | |||
| self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id, | |||
| non_mappable=True) | |||
| self.shard_equal_rows = shard_equal_rows | |||
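| For example (the path and session id are placeholders): with a cache above it, the reader is reconfigured as a simple producer that parses all files, and the arguments below take effect through the cache's sampler instead: | |||
| | |||
| import mindspore.dataset as ds | |||
| | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=False) | |||
| # num_samples and shuffle are carried by the sampler that the cache will own. | |||
| data = ds.TFRecordDataset(["/path/to/data.tfrecord"], num_samples=100, | |||
|                           shuffle=ds.Shuffle.GLOBAL, cache=some_cache) | |||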
| def get_args(self): | |||
| @@ -3318,6 +3340,8 @@ class TFRecordDataset(SourceDataset): | |||
| args["num_shards"] = self.num_shards | |||
| args["shard_id"] = self.shard_id | |||
| args["shard_equal_rows"] = self.shard_equal_rows | |||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||
| args["sampler"] = self.sampler | |||
| return args | |||
| def get_dataset_size(self, estimate=False): | |||
| @@ -3803,43 +3827,61 @@ class RandomDataset(SourceDataset): | |||
| A source dataset that generates random data. | |||
| Args: | |||
| num_samples (int): number of samples to generate. | |||
| total_rows (int): number of rows for the dataset to generate (default=None, number of rows is random) | |||
| schema (str or Schema, optional): Path to the json schema file or schema object (default=None). | |||
| If the schema is not provided, the random dataset generates a random schema. | |||
| columns_list (list[str], optional): List of columns to be read (default=None, read all columns) | |||
| num_samples (int): number of samples to draw from the total. (default=None, which means all rows) | |||
| num_parallel_workers (int, optional): number of workers to read the data | |||
| (default=None, number set in the config). | |||
| cache (DatasetCache, optional): Tensor cache to use (default=None, which means no cache is used). | |||
| shuffle (bool, optional): Whether or not to perform shuffle on the dataset | |||
| (default=None, expected order behavior shown in the table). | |||
| num_shards (int, optional): Number of shards that the dataset should be divided | |||
| into (default=None). | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument should be specified only when num_shards is also specified. | |||
| """ | |||
| def __init__(self, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None): | |||
| @check_random_dataset | |||
| def __init__(self, total_rows=None, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, | |||
| cache=None, shuffle=None, num_shards=None, shard_id=None): | |||
| super().__init__(num_parallel_workers) | |||
| schema_obj = None | |||
| if (schema is not None) and (not isinstance(schema, Schema)): | |||
| schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it | |||
| self.schema = schema | |||
| self.columns_list = columns_list | |||
| if schema_obj is not None and num_samples is None: | |||
| self.num_samples = schema_obj.num_rows | |||
| elif num_samples is None: | |||
| self.num_samples = 0 | |||
| sampler = None | |||
| self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id, non_mappable=True) | |||
| self.num_samples = num_samples | |||
| self.cache = cache | |||
| if schema_obj is not None and total_rows is None: | |||
| self.total_rows = schema_obj.num_rows | |||
| elif total_rows is None: | |||
| self.total_rows = 0 | |||
| else: | |||
| self.num_samples = num_samples | |||
| self.total_rows = total_rows | |||
| self.num_shards = num_shards | |||
| self.shard_id = shard_id | |||
| self.shuffle_level = shuffle | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| if self.schema is not None: | |||
| if isinstance(self.schema, Schema): | |||
| self.schema.datasetType = 'Random' | |||
| if self.num_samples is not None: | |||
| self.schema.num_rows = self.num_samples | |||
| if self.total_rows is not None: | |||
| self.schema.num_rows = self.total_rows | |||
| args["schema_json_string"] = self.schema.to_json() | |||
| else: | |||
| args["schema_file_path"] = self.schema | |||
| args["schema"] = self.schema | |||
| if self.columns_list is not None: | |||
| args["columns_list"] = self.columns_list | |||
| if self.num_samples is not None: | |||
| args["num_samples"] = self.num_samples | |||
| args["columns_list"] = self.columns_list | |||
| args["num_samples"] = self.num_samples | |||
| args["total_rows"] = self.total_rows | |||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||
| args["sampler"] = self.sampler | |||
| return args | |||
| def get_dataset_size(self): | |||
| @@ -3849,18 +3891,29 @@ class RandomDataset(SourceDataset): | |||
| Return: | |||
| Number, number of batches. | |||
| """ | |||
| num_rows = self.total_rows | |||
| rows_per_shard = get_num_rows(num_rows, self.num_shards) | |||
| rows_from_sampler = self._get_sampler_dataset_size() | |||
| if rows_from_sampler is None: | |||
| return self.num_samples | |||
| return rows_per_shard | |||
| return min(rows_from_sampler, self.num_samples) | |||
| return min(rows_from_sampler, rows_per_shard) | |||
| def is_shuffled(self): | |||
| return True | |||
| if self.shuffle_level is None: | |||
| return True | |||
| return self.shuffle_level or self.sampler.is_shuffled() | |||
| def is_sharded(self): | |||
| return False | |||
| if self.num_shards is not None: | |||
| return self.num_shards > 1 | |||
| # The sampler can be None when no sampling arguments were given (non-mappable case). | |||
| return self.sampler is not None and self.sampler.is_sharded() | |||
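| A short sketch of the reworked argument split (the values are illustrative): total_rows sizes the generated data, while num_samples controls how many of those rows are drawn by the sampler: | |||
| | |||
| import mindspore.dataset as ds | |||
| | |||
| # Generate 64 random rows (random schema, since none is given) and draw 32 of them. | |||
| data = ds.RandomDataset(total_rows=64, num_samples=32, shuffle=False) | |||
| print(data.get_dataset_size())  # bounded by both the sampler and rows per shard | |||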
| class Schema: | |||
| @@ -173,7 +173,9 @@ def traverse(node): | |||
| # num_samples, shard_id, num_shards, shuffle | |||
| # These arguments get moved into the sampler itself, so they are no longer needed to | |||
| # be set at the dataset level. | |||
| if 'sampler' in node_args.keys(): | |||
| # TF Record is a special case because it uses both the dataset and sampler arguments, | |||
| # which are not decided until later, during the tree preparation phase. | |||
| if node_repr['op_type'] != 'TFRecordDataset' and 'sampler' in node_args.keys(): | |||
| if 'num_samples' in node_repr.keys(): | |||
| node_repr['num_samples'] = None | |||
| if 'shuffle' in node_repr.keys(): | |||
| @@ -29,10 +29,11 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_lis | |||
| from . import datasets | |||
| from . import samplers | |||
| from . import cache_client | |||
| def check_imagefolderdatasetv2(method): | |||
| """A wrapper that wrap a parameter checker to the original Dataset(ImageFolderDatasetV2).""" | |||
| """A wrapper that wraps a parameter checker to the original Dataset(ImageFolderDatasetV2).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -58,7 +59,7 @@ def check_imagefolderdatasetv2(method): | |||
| def check_mnist_cifar_dataset(method): | |||
| """A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset).""" | |||
| """A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -81,7 +82,7 @@ def check_mnist_cifar_dataset(method): | |||
| def check_manifestdataset(method): | |||
| """A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset).""" | |||
| """A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -108,7 +109,7 @@ def check_manifestdataset(method): | |||
| def check_tfrecorddataset(method): | |||
| """A wrapper that wrap a parameter checker to the original Dataset(TFRecordDataset).""" | |||
| """A wrapper that wraps a parameter checker to the original Dataset(TFRecordDataset).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -134,7 +135,7 @@ def check_tfrecorddataset(method): | |||
| def check_vocdataset(method): | |||
| """A wrapper that wrap a parameter checker to the original Dataset(VOCDataset).""" | |||
| """A wrapper that wraps a parameter checker to the original Dataset(VOCDataset).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -175,7 +176,7 @@ def check_vocdataset(method): | |||
| def check_cocodataset(method): | |||
| """A wrapper that wrap a parameter checker to the original Dataset(CocoDataset).""" | |||
| """A wrapper that wraps a parameter checker to the original Dataset(CocoDataset).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -211,7 +212,7 @@ def check_cocodataset(method): | |||
| def check_celebadataset(method): | |||
| """A wrapper that wrap a parameter checker to the original Dataset(CelebADataset).""" | |||
| """A wrapper that wraps a parameter checker to the original Dataset(CelebADataset).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -247,7 +248,7 @@ def check_celebadataset(method): | |||
| def check_minddataset(method): | |||
| """A wrapper that wrap a parameter checker to the original Dataset(MindDataset).""" | |||
| """A wrapper that wraps a parameter checker to the original Dataset(MindDataset).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -279,7 +280,7 @@ def check_minddataset(method): | |||
| def check_generatordataset(method): | |||
| """A wrapper that wrap a parameter checker to the original Dataset(GeneratorDataset).""" | |||
| """A wrapper that wraps a parameter checker to the original Dataset(GeneratorDataset).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -344,6 +345,27 @@ def check_generatordataset(method): | |||
| return new_method | |||
| def check_random_dataset(method): | |||
| """A wrapper that wraps a parameter checker to the original Dataset(RandomDataset).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| _, param_dict = parse_user_args(method, *args, **kwargs) | |||
| nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id', 'total_rows'] | |||
| nreq_param_bool = ['shuffle'] | |||
| nreq_param_list = ['columns_list'] | |||
| validate_dataset_param_value(nreq_param_int, param_dict, int) | |||
| validate_dataset_param_value(nreq_param_bool, param_dict, bool) | |||
| validate_dataset_param_value(nreq_param_list, param_dict, list) | |||
| check_sampler_shuffle_shard_options(param_dict) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| def check_pad_info(key, val): | |||
| """check the key and value pair of pad_info in batch""" | |||
| @@ -506,7 +528,7 @@ def check_map(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| [input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing], _ = \ | |||
| [input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache], _ = \ | |||
| parse_user_args(method, *args, **kwargs) | |||
| nreq_param_columns = ['input_columns', 'output_columns'] | |||
| @@ -516,6 +538,8 @@ def check_map(method): | |||
| if num_parallel_workers is not None: | |||
| check_num_parallel_workers(num_parallel_workers) | |||
| type_check(python_multiprocessing, (bool,), "python_multiprocessing") | |||
| if cache is not None: | |||
| type_check(cache, (cache_client.DatasetCache,), "cache") | |||
| for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]): | |||
| if param is not None: | |||
| @@ -720,7 +744,7 @@ def check_add_column(method): | |||
| def check_cluedataset(method): | |||
| """A wrapper that wrap a parameter checker to the original Dataset(CLUEDataset).""" | |||
| """A wrapper that wraps a parameter checker to the original Dataset(CLUEDataset).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -750,7 +774,7 @@ def check_cluedataset(method): | |||
| def check_textfiledataset(method): | |||
| """A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset).""" | |||
| """A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -823,7 +847,7 @@ def check_gnn_graphdata(method): | |||
| def check_gnn_get_all_nodes(method): | |||
| """A wrapper that wrap a parameter checker to the GNN `get_all_nodes` function.""" | |||
| """A wrapper that wraps a parameter checker to the GNN `get_all_nodes` function.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -836,7 +860,7 @@ def check_gnn_get_all_nodes(method): | |||
| def check_gnn_get_all_edges(method): | |||
| """A wrapper that wrap a parameter checker to the GNN `get_all_edges` function.""" | |||
| """A wrapper that wraps a parameter checker to the GNN `get_all_edges` function.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -849,7 +873,7 @@ def check_gnn_get_all_edges(method): | |||
| def check_gnn_get_nodes_from_edges(method): | |||
| """A wrapper that wrap a parameter checker to the GNN `get_nodes_from_edges` function.""" | |||
| """A wrapper that wraps a parameter checker to the GNN `get_nodes_from_edges` function.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -862,7 +886,7 @@ def check_gnn_get_nodes_from_edges(method): | |||
| def check_gnn_get_all_neighbors(method): | |||
| """A wrapper that wrap a parameter checker to the GNN `get_all_neighbors` function.""" | |||
| """A wrapper that wraps a parameter checker to the GNN `get_all_neighbors` function.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -877,7 +901,7 @@ def check_gnn_get_all_neighbors(method): | |||
| def check_gnn_get_sampled_neighbors(method): | |||
| """A wrapper that wrap a parameter checker to the GNN `get_sampled_neighbors` function.""" | |||
| """A wrapper that wraps a parameter checker to the GNN `get_sampled_neighbors` function.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -905,7 +929,7 @@ def check_gnn_get_sampled_neighbors(method): | |||
| def check_gnn_get_neg_sampled_neighbors(method): | |||
| """A wrapper that wrap a parameter checker to the GNN `get_neg_sampled_neighbors` function.""" | |||
| """A wrapper that wraps a parameter checker to the GNN `get_neg_sampled_neighbors` function.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -921,7 +945,7 @@ def check_gnn_get_neg_sampled_neighbors(method): | |||
| def check_gnn_random_walk(method): | |||
| """A wrapper that wrap a parameter checker to the GNN `random_walk` function.""" | |||
| """A wrapper that wraps a parameter checker to the GNN `random_walk` function.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -968,7 +992,7 @@ def check_aligned_list(param, param_name, member_type): | |||
| def check_gnn_get_node_feature(method): | |||
| """A wrapper that wrap a parameter checker to the GNN `get_node_feature` function.""" | |||
| """A wrapper that wraps a parameter checker to the GNN `get_node_feature` function.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -1012,7 +1036,7 @@ def check_gnn_get_edge_feature(method): | |||
| def check_numpyslicesdataset(method): | |||
| """A wrapper that wrap a parameter checker to the original Dataset(NumpySlicesDataset).""" | |||
| """A wrapper that wraps a parameter checker to the original Dataset(NumpySlicesDataset).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -39,7 +39,7 @@ def check_unique_list_of_words(words, arg_name): | |||
| def check_lookup(method): | |||
| """A wrapper that wrap a parameter checker to the original function.""" | |||
| """A wrapper that wraps a parameter checker to the original function.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -56,7 +56,7 @@ def check_lookup(method): | |||
| def check_from_file(method): | |||
| """A wrapper that wrap a parameter checker to the original function.""" | |||
| """A wrapper that wraps a parameter checker to the original function.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -74,7 +74,7 @@ def check_from_file(method): | |||
| def check_from_list(method): | |||
| """A wrapper that wrap a parameter checker to the original function.""" | |||
| """A wrapper that wraps a parameter checker to the original function.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -97,7 +97,7 @@ def check_from_list(method): | |||
| def check_from_dict(method): | |||
| """A wrapper that wrap a parameter checker to the original function.""" | |||
| """A wrapper that wraps a parameter checker to the original function.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -285,7 +285,7 @@ def check_bert_tokenizer(method): | |||
| def check_from_dataset(method): | |||
| """A wrapper that wrap a parameter checker to the original function.""" | |||
| """A wrapper that wraps a parameter checker to the original function.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -328,7 +328,7 @@ def check_from_dataset(method): | |||
| def check_ngram(method): | |||
| """A wrapper that wrap a parameter checker to the original function.""" | |||
| """A wrapper that wraps a parameter checker to the original function.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -114,7 +114,7 @@ def check_erasing_value(value): | |||
| def check_crop(method): | |||
| """A wrapper that wrap a parameter checker to the original function(crop operation).""" | |||
| """A wrapper that wraps a parameter checker to the original function(crop operation).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -127,7 +127,7 @@ def check_crop(method): | |||
| def check_resize_interpolation(method): | |||
| """A wrapper that wrap a parameter checker to the original function(resize interpolation operation).""" | |||
| """A wrapper that wraps a parameter checker to the original function(resize interpolation operation).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -142,7 +142,7 @@ def check_resize_interpolation(method): | |||
| def check_resize(method): | |||
| """A wrapper that wrap a parameter checker to the original function(resize operation).""" | |||
| """A wrapper that wraps a parameter checker to the original function(resize operation).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -155,7 +155,7 @@ def check_resize(method): | |||
| def check_random_resize_crop(method): | |||
| """A wrapper that wrap a parameter checker to the original function(random resize crop operation).""" | |||
| """A wrapper that wraps a parameter checker to the original function(random resize crop operation).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -178,7 +178,7 @@ def check_random_resize_crop(method): | |||
| def check_prob(method): | |||
| """A wrapper that wrap a parameter checker(check the probability) to the original function.""" | |||
| """A wrapper that wraps a parameter checker(check the probability) to the original function.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -192,7 +192,7 @@ def check_prob(method): | |||
| def check_normalize_c(method): | |||
| """A wrapper that wrap a parameter checker to the original function(normalize operation written in C++).""" | |||
| """A wrapper that wraps a parameter checker to the original function(normalize operation written in C++).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -205,7 +205,7 @@ def check_normalize_c(method): | |||
| def check_normalize_py(method): | |||
| """A wrapper that wrap a parameter checker to the original function(normalize operation written in Python).""" | |||
| """A wrapper that wraps a parameter checker to the original function(normalize operation written in Python).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| @@ -738,7 +738,7 @@ TEST_F(MindDataTestPipeline, TestProjectMap) { | |||
| EXPECT_TRUE(ds != nullptr); | |||
| // Create a Project operation on ds | |||
| std::vector<std::string> column_project = {"label"}; | |||
| std::vector<std::string> column_project = {"image"}; | |||
| ds = ds->Project(column_project); | |||
| EXPECT_TRUE(ds != nullptr); | |||
| @@ -0,0 +1,579 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include "dataset/core/client.h" | |||
| #include "dataset/engine/cache/cache_client.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/engine/datasetops/cache_op.h" | |||
| #include "dataset/engine/datasetops/cache_lookup_op.h" | |||
| #include "dataset/engine/datasetops/cache_merge_op.h" | |||
| #include "dataset/engine/datasetops/source/image_folder_op.h" | |||
| #include "common/common.h" | |||
| #include "gtest/gtest.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "dataset/util/storage_container.h" // lint !e322 | |||
| #include "dataset/engine/datasetops/source/random_data_op.h" | |||
| #include "dataset/engine/data_schema.h" | |||
| using namespace mindspore::dataset; | |||
| using mindspore::LogStream; | |||
| using mindspore::dataset::CacheClient; | |||
| using mindspore::dataset::TaskGroup; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::MsLogLevel::INFO; | |||
| class MindDataTestCacheOp : public UT::DatasetOpTesting { | |||
| public: | |||
| void SetUp() override { | |||
| DatasetOpTesting::SetUp(); | |||
| GlobalInit(); | |||
| } | |||
| }; | |||
| TEST_F(MindDataTestCacheOp, TestCacheServer) { | |||
| Status rc; | |||
| CacheClient myClient(1, 0, true); // use arbitrary session of 1, size of 0, spilling is true | |||
| // A cksum value of 1 is passed to CreateCache here; normally you do not create a cache directly, and the cksum arg is generated. | |||
| rc = myClient.CreateCache(1, true); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| std::cout << myClient << std::endl; | |||
| // Create a schema using the C APIs | |||
| int32_t rank = 0; // not used | |||
| std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>(); | |||
| // Two columns. The first column is an "image" of shape 640,480,3 | |||
| TensorShape c1Shape({640, 480, 3}); | |||
| ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible, | |||
| rank, // not used | |||
| &c1Shape); | |||
| // Column 2 will just be a scalar label number | |||
| TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor | |||
| ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape); | |||
| testSchema->AddColumn(c1); | |||
| testSchema->AddColumn(c2); | |||
| std::unordered_map<std::string, int32_t> map; | |||
| rc = testSchema->GetColumnNameMap(&map); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Test the CacheSchema api | |||
| rc = myClient.CacheSchema(map); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Create a tensor, write it to the cache, read it back, and compare. | |||
| std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); | |||
| t->SetItemAt<uint64_t>({0, 0}, 1); | |||
| t->SetItemAt<uint64_t>({0, 1}, 2); | |||
| t->SetItemAt<uint64_t>({0, 2}, 3); | |||
| t->SetItemAt<uint64_t>({1, 0}, 4); | |||
| t->SetItemAt<uint64_t>({1, 1}, 5); | |||
| t->SetItemAt<uint64_t>({1, 2}, 6); | |||
| std::cout << *t << std::endl; | |||
| TensorTable tbl; | |||
| TensorRow row; | |||
| row.push_back(t); | |||
| int64_t row_id; | |||
| rc = myClient.WriteRow(row, &row_id); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Switch off build phase. | |||
| rc = myClient.BuildPhaseDone(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Now restore from cache. | |||
| row.clear(); | |||
| rc = myClient.GetRows({row_id}, &tbl); | |||
| row = tbl.front(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| auto r = row.front(); | |||
| std::cout << *r << std::endl; | |||
| // Compare | |||
| bool cmp = (*t == *r); | |||
| EXPECT_TRUE(cmp); | |||
| // Get back the schema and verify | |||
| std::unordered_map<std::string, int32_t> map_out; | |||
| rc = myClient.FetchSchema(&map_out); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| cmp = (map_out == map); | |||
| EXPECT_TRUE(cmp); | |||
| // Test Purge and Destroy | |||
| rc = myClient.PurgeCache(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myClient.DestroyCache(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| } | |||
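| The WriteRow/GetRows round trip above is the same path the Python-level cache drives. A minimal user-facing sketch, assuming only the DatasetCache API and `cache` parameter introduced by the Python tests later in this PR (the dataset path is a placeholder): | |||
| import mindspore.dataset as ds | |||
| # Arbitrary session id; size=0 means no explicit cap; spilling=True lets the | |||
| # server spill rows to disk, mirroring the C++ client settings above. | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| # The directory is hypothetical; any leaf that accepts `cache` works the same way. | |||
| ds1 = ds.ImageFolderDatasetV2(dataset_dir="/path/to/images", cache=some_cache) | |||
| for _ in ds1.create_dict_iterator(): | |||
|     pass  # the first pass populates the cache; repeats read back from it | |||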
| TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) { | |||
| // Clear the rc of the master thread if any | |||
| (void)TaskManager::GetMasterThreadRc(); | |||
| TaskGroup vg; | |||
| Status rc; | |||
| CacheClient myClient(1, 1, true); // use arbitrary session of 1, size 1, spilling is true | |||
| // A cksum value of 1 is passed to CreateCache here; normally you do not create a cache directly, and the cksum arg is generated. | |||
| rc = myClient.CreateCache(1, true); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| std::cout << myClient << std::endl; | |||
| std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); | |||
| t->SetItemAt<uint64_t>({0, 0}, 1); | |||
| t->SetItemAt<uint64_t>({0, 1}, 2); | |||
| t->SetItemAt<uint64_t>({0, 2}, 3); | |||
| t->SetItemAt<uint64_t>({1, 0}, 4); | |||
| t->SetItemAt<uint64_t>({1, 1}, 5); | |||
| t->SetItemAt<uint64_t>({1, 2}, 6); | |||
| TensorTable tbl; | |||
| TensorRow row; | |||
| row.push_back(t); | |||
| // Cache tensor row t 5000 times using 10 threads. | |||
| for (auto k = 0; k < 10; ++k) { | |||
| Status vg_rc = vg.CreateAsyncTask("Test agent", [&myClient, &row]() -> Status { | |||
| TaskManager::FindMe()->Post(); | |||
| for (auto i = 0; i < 500; i++) { | |||
| RETURN_IF_NOT_OK(myClient.WriteRow(row)); | |||
| } | |||
| return Status::OK(); | |||
| }); | |||
| EXPECT_TRUE(vg_rc.IsOk()); | |||
| } | |||
| ASSERT_TRUE(vg.join_all().IsOk()); | |||
| ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOk()); | |||
| rc = myClient.BuildPhaseDone(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Get statistics from the server. | |||
| CacheClient::ServiceStat stat{}; | |||
| rc = myClient.GetStat(&stat); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| std::cout << stat.min_row_id << ":" << stat.max_row_id << ":" << stat.num_mem_cached << ":" << stat.num_disk_cached | |||
| << "\n"; | |||
| // Expect 5000 rows in the cache (10 threads x 500 writes). | |||
| EXPECT_EQ(5000, stat.max_row_id - stat.min_row_id + 1); | |||
| // Get them all back using row id and compare with tensor t. | |||
| for (auto i = stat.min_row_id; i <= stat.max_row_id; ++i) { | |||
| tbl.clear(); | |||
| row.clear(); | |||
| rc = myClient.GetRows({i}, &tbl); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| row = tbl.front(); | |||
| auto r = row.front(); | |||
| bool cmp = (*t == *r); | |||
| EXPECT_TRUE(cmp); | |||
| } | |||
| rc = myClient.DestroyCache(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| } | |||
| // Simple test with a repeated cache op over a random data producer | |||
| // | |||
| // RepeatOp | |||
| // | | |||
| // CacheOp | |||
| // | | |||
| // RandomDataOp | |||
| // | |||
| TEST_F(MindDataTestCacheOp, TestRandomDataCache1) { | |||
| Status rc; | |||
| int32_t rank = 0; // not used | |||
| MS_LOG(INFO) << "UT test TestRandomDataCache1"; | |||
| // Start with an empty execution tree | |||
| auto myTree = std::make_shared<ExecutionTree>(); | |||
| // Create a schema using the C APIs | |||
| std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>(); | |||
| // Two columns. The first column is an "image" of shape 640,480,3 | |||
| TensorShape c1Shape({640, 480, 3}); | |||
| ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible, | |||
| rank, // not used | |||
| &c1Shape); | |||
| // Column 2 will just be a scalar label number | |||
| TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor | |||
| ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape); | |||
| testSchema->AddColumn(c1); | |||
| testSchema->AddColumn(c2); | |||
| // RandomDataOp | |||
| std::shared_ptr<RandomDataOp> myRandomDataOp; | |||
| rc = RandomDataOp::Builder() | |||
| .SetRowsPerBuffer(4) | |||
| .SetNumWorkers(4) | |||
| .SetDataSchema(std::move(testSchema)) | |||
| .SetTotalRows(50) // 50 samples for now | |||
| .Build(&myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // CacheOp | |||
| // size of 0, spilling is true | |||
| std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true); | |||
| std::shared_ptr<CacheOp> myCacheOp; | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| rc = CacheOp::Builder() | |||
| .SetNumWorkers(5) | |||
| .SetClient(myClient) | |||
| .SetRowsPerBuffer(4) | |||
| .SetSampler(std::move(seq_sampler)) | |||
| .Build(&myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // RepeatOp | |||
| uint32_t numRepeats = 4; | |||
| std::shared_ptr<RepeatOp> myRepeatOp; | |||
| rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Assign tree relations and root | |||
| rc = myRepeatOp->AddChild(myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myCacheOp->AddChild(myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssignRoot(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | |||
| rc = myTree->Prepare(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // quick check to see what tree looks like | |||
| std::ostringstream ss; | |||
| ss << *myTree;  // buffer through a stringstream; writing directly to the MS log stream trips a const error | |||
| MS_LOG(INFO) << "Here's the tree:\n" << ss.str(); | |||
| std::cout << *myClient << std::endl; | |||
| rc = myTree->Launch(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Start the loop of reading tensors from our pipeline | |||
| DatasetIterator dI(myTree); | |||
| TensorRow tensorList; | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| int rowCount = 0; | |||
| while (!tensorList.empty()) { | |||
| // Don't display these rows, just count them | |||
| MS_LOG(INFO) << "Row fetched #: " << rowCount; | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rowCount++; | |||
| } | |||
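| // 50 rows x 4 repeats through the cache = 200 fetched rows | |||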
| ASSERT_EQ(rowCount, 200); | |||
| rc = myClient->DestroyCache(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| } | |||
| // Simple test with a repeated cache op over a random data producer. | |||
| // This one will exceed memory and require a spill. | |||
| // | |||
| // RepeatOp | |||
| // | | |||
| // CacheOp | |||
| // | | |||
| // RandomDataOp | |||
| // | |||
| TEST_F(MindDataTestCacheOp, TestRandomDataCacheSpill) { | |||
| Status rc; | |||
| int32_t rank = 0; // not used | |||
| MS_LOG(INFO) << "UT test TestRandomDataCacheSpill"; | |||
| // Start with an empty execution tree | |||
| auto myTree = std::make_shared<ExecutionTree>(); | |||
| // Create a schema using the C APIs | |||
| std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>(); | |||
| // Two columns. The first column is an "image" of shape 640,480,3 | |||
| TensorShape c1Shape({640, 480, 3}); | |||
| ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible, | |||
| rank, // not used | |||
| &c1Shape); | |||
| // Column 2 will just be a scalar label number | |||
| TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor | |||
| ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape); | |||
| testSchema->AddColumn(c1); | |||
| testSchema->AddColumn(c2); | |||
| // RandomDataOp | |||
| std::shared_ptr<RandomDataOp> myRandomDataOp; | |||
| rc = RandomDataOp::Builder() | |||
| .SetRowsPerBuffer(2) | |||
| .SetNumWorkers(4) | |||
| .SetDataSchema(std::move(testSchema)) | |||
| .SetTotalRows(10) | |||
| .Build(&myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // CacheOp | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 4, true); | |||
| std::shared_ptr<CacheOp> myCacheOp; | |||
| rc = CacheOp::Builder() | |||
| .SetNumWorkers(4) | |||
| .SetClient(myClient) | |||
| .SetRowsPerBuffer(3) | |||
| .SetSampler(std::move(seq_sampler)) | |||
| .Build(&myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // RepeatOp | |||
| uint32_t numRepeats = 4; | |||
| std::shared_ptr<RepeatOp> myRepeatOp; | |||
| rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Assign tree relations and root | |||
| rc = myRepeatOp->AddChild(myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myCacheOp->AddChild(myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssignRoot(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | |||
| rc = myTree->Prepare(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| std::cout << *myClient << std::endl; | |||
| rc = myTree->Launch(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Start the loop of reading tensors from our pipeline | |||
| DatasetIterator dI(myTree); | |||
| TensorRow tensorList; | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| int rowCount = 0; | |||
| while (!tensorList.empty()) { | |||
| // Don't display these rows, just count them | |||
| MS_LOG(INFO) << "Row fetched #: " << rowCount; | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rowCount++; | |||
| } | |||
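| // 10 rows x 4 repeats = 40 fetched rows | |||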
| ASSERT_EQ(rowCount, 40); | |||
| rc = myClient->DestroyCache(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| } | |||
| TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { | |||
| Status rc; | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true); | |||
| std::shared_ptr<CacheMergeOp> myMergeOp; | |||
| rc = CacheMergeOp::Builder().SetNumWorkers(3).SetOpConnectorSize(3).SetNumCleaner(2).SetClient(myClient).Build( | |||
| &myMergeOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| std::shared_ptr<CacheLookupOp> myLookupOp; | |||
| rc = CacheLookupOp::Builder() | |||
| .SetNumWorkers(3) | |||
| .SetOpConnectorSize(3) | |||
| .SetClient(myClient) | |||
| .SetSampler(seq_sampler) | |||
| .Build(&myLookupOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| std::shared_ptr<ImageFolderOp> so; | |||
| ImageFolderOp::Builder builder; | |||
| builder.SetSampler(myLookupOp) | |||
| .SetOpConnectorSize(3) | |||
| .SetNumWorkers(3) | |||
| .SetRowsPerBuffer(2) | |||
| .SetExtensions({".jpg", ".JPEG"}) | |||
| .SetRecursive(true) | |||
| .SetImageFolderDir(datasets_root_path_ + "/testPK/data"); | |||
| rc = builder.Build(&so); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // RepeatOp | |||
| uint32_t numRepeats = 4; | |||
| std::shared_ptr<RepeatOp> myRepeatOp; | |||
| rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| auto myTree = std::make_shared<ExecutionTree>(); | |||
| rc = myTree->AssociateNode(so); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myLookupOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myMergeOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssignRoot(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myRepeatOp->AddChild(myMergeOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myMergeOp->AddChild(myLookupOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myMergeOp->AddChild(so); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->Prepare(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->Launch(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Start the loop of reading tensors from our pipeline | |||
| DatasetIterator dI(myTree); | |||
| TensorRow tensorList; | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| int rowCount = 0; | |||
| while (!tensorList.empty()) { | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| if (rc.IsError()) { | |||
| std::cout << rc << std::endl; | |||
| break; | |||
| } | |||
| rowCount++; | |||
| } | |||
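| // 4 repeats over the ImageFolder data; 176 rows implies 44 images under testPK/data | |||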
| ASSERT_EQ(rowCount, 176); | |||
| std::cout << "Row count : " << rowCount << std::endl; | |||
| rc = myClient->DestroyCache(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| } | |||
| // Simple test with a repeated cache op over a random data producer. | |||
| // The difference in this one is that the sampler is not added to the cache op directly. | |||
| // Instead, the sampler is added as part of the leaf op construction. The prepare phase | |||
| // will then pull it up from the leaf and into the cache, removing it from the leaf op, | |||
| // where it makes no sense anyway: RandomDataOp does not support sampling without a cache. | |||
| // (A user-level sketch of this pull-up follows the test.) | |||
| // | |||
| // RepeatOp | |||
| // | | |||
| // CacheOp | |||
| // | | |||
| // RandomDataOp | |||
| // | |||
| TEST_F(MindDataTestCacheOp, TestCacheInheritSampler) { | |||
| Status rc; | |||
| int32_t rank = 0; // not used | |||
| MS_LOG(INFO) << "UT test TestCacheInheritSampler"; | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| // Start with an empty execution tree | |||
| auto myTree = std::make_shared<ExecutionTree>(); | |||
| // Create a schema using the C APIs | |||
| std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>(); | |||
| // Two columns. The first column is an "image" of shape 640,480,3 | |||
| TensorShape c1Shape({640, 480, 3}); | |||
| ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible, | |||
| rank, // not used | |||
| &c1Shape); | |||
| // Column 2 will just be a scalar label number | |||
| TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor | |||
| ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape); | |||
| testSchema->AddColumn(c1); | |||
| testSchema->AddColumn(c2); | |||
| // RandomDataOp | |||
| std::shared_ptr<RandomDataOp> myRandomDataOp; | |||
| rc = RandomDataOp::Builder() | |||
| .SetRowsPerBuffer(2) | |||
| .SetNumWorkers(4) | |||
| .SetDataSchema(std::move(testSchema)) | |||
| .SetTotalRows(10) | |||
| .SetSampler(std::move(seq_sampler)) | |||
| .Build(&myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // CacheOp | |||
| std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 4, true); | |||
| std::shared_ptr<CacheOp> myCacheOp; | |||
| rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // RepeatOp | |||
| uint32_t numRepeats = 4; | |||
| std::shared_ptr<RepeatOp> myRepeatOp; | |||
| rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Assign tree relations and root | |||
| rc = myRepeatOp->AddChild(myCacheOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myCacheOp->AddChild(myRandomDataOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssignRoot(myRepeatOp); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | |||
| rc = myTree->Prepare(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| std::cout << *myClient << std::endl; | |||
| rc = myTree->Launch(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Start the loop of reading tensors from our pipeline | |||
| DatasetIterator dI(myTree); | |||
| TensorRow tensorList; | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| int rowCount = 0; | |||
| while (!tensorList.empty()) { | |||
| // Don't display these rows, just count them | |||
| MS_LOG(INFO) << "Row fetched #: " << rowCount; | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rowCount++; | |||
| } | |||
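| // 10 rows x 4 repeats = 40 fetched rows | |||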
| ASSERT_EQ(rowCount, 40); | |||
| rc = myClient->DestroyCache(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| } | |||
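| The sampler pull-up exercised by TestCacheInheritSampler has a Python-level analog in the tests that follow: sampler-generating arguments are given to the leaf, and when a cache is present, tree prepare re-homes the sampling inside the cache. A hedged sketch using the same arguments as the new tests (the behavior description comes from the comments above, not separate documentation): | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset as ds | |||
| schema = ds.Schema() | |||
| schema.add_column('image', de_type=mstype.uint8, shape=[2]) | |||
| schema.add_column('label', de_type=mstype.uint8, shape=[1]) | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| # num_samples would normally generate a sampler on the leaf; with the cache | |||
| # in the tree, prepare pulls that sampling up into the cache operator. | |||
| ds1 = ds.RandomDataset(schema=schema, total_rows=20, num_samples=20, | |||
|                        num_parallel_workers=4, cache=some_cache) | |||
| for _ in ds1.create_dict_iterator(): | |||
|     pass | |||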
| @@ -0,0 +1,157 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing cache operator with mappable datasets | |||
| """ | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| from mindspore import log as logger | |||
| from util import save_and_check_md5 | |||
| DATA_DIR = "../data/dataset/testImageNetData/train/" | |||
| GENERATE_GOLDEN = False | |||
| def test_cache_map_basic1(): | |||
| """ | |||
| Test mappable leaf with cache op right over the leaf | |||
| Repeat | |||
| | | |||
| Map(decode) | |||
| | | |||
| Cache | |||
| | | |||
| ImageFolder | |||
| """ | |||
| logger.info("Test cache map basic 1") | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR, cache=some_cache) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op) | |||
| ds1 = ds1.repeat(4) | |||
| filename = "cache_map_01_result.npz" | |||
| save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN) | |||
| logger.info("test_cache_map_basic1 Ended.\n") | |||
| def test_cache_map_basic2(): | |||
| """ | |||
| Test mappable leaf with the cache op later in the tree above the map(decode) | |||
| Repeat | |||
| | | |||
| Cache | |||
| | | |||
| Map(decode) | |||
| | | |||
| ImageFolder | |||
| """ | |||
| logger.info("Test cache map basic 2") | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) | |||
| ds1 = ds1.repeat(4) | |||
| filename = "cache_map_02_result.npz" | |||
| save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN) | |||
| logger.info("test_cache_map_basic2 Ended.\n") | |||
| def test_cache_map_basic3(): | |||
| """ | |||
| Test a repeat under mappable cache | |||
| Cache | |||
| | | |||
| Map(decode) | |||
| | | |||
| Repeat | |||
| | | |||
| ImageFolder | |||
| """ | |||
| logger.info("Test cache basic 3") | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.repeat(4) | |||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 8 | |||
| logger.info("test_cache_map_basic3 Ended.\n") | |||
| def test_cache_map_failure1(): | |||
| """ | |||
| Test nested cache (failure) | |||
| Repeat | |||
| | | |||
| Cache | |||
| | | |||
| Map(decode) | |||
| | | |||
| Cache | |||
| | | |||
| ImageFolder | |||
| """ | |||
| logger.info("Test cache failure 1") | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR, cache=some_cache) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) | |||
| ds1 = ds1.repeat(4) | |||
| try: | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| except RuntimeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Nested cache operations is not supported!" in str(e) | |||
| assert num_iter == 0 | |||
| logger.info("test_cache_map_failure1 Ended.\n") | |||
| if __name__ == '__main__': | |||
| test_cache_map_basic1() | |||
| test_cache_map_basic2() | |||
| test_cache_map_basic3() | |||
| test_cache_map_failure1() | |||
| @@ -0,0 +1,429 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing cache operator with non-mappable datasets | |||
| """ | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| from mindspore import log as logger | |||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| GENERATE_GOLDEN = False | |||
| def test_cache_nomap_basic1(): | |||
| """ | |||
| A random dataset (a non-mappable dataset) with a cache over it just after the leaf | |||
| """ | |||
| logger.info("Test cache nomap basic 1") | |||
| schema = ds.Schema() | |||
| schema.add_column('image', de_type=mstype.uint8, | |||
| shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) | |||
| schema.add_column('label', de_type=mstype.uint8, shape=[1]) | |||
| # create a cache. arbitrary session_id for now | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| # No explicit sampler here; with the cache present, sampling happens over the cache | |||
| ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=4, cache=some_cache) | |||
| ds1 = ds1.repeat(4) | |||
| num_iter = 0 | |||
| for data in ds1.create_dict_iterator(): | |||
| logger.info("printing the label: {}".format(data["label"])) | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 40 | |||
| logger.info("test_cache_nomap_basic1 Ended.\n") | |||
| def test_cache_nomap_basic2(): | |||
| """ | |||
| A random dataset (a non-mappable dataset) with a cache over it, the sampler auto-generated from num_samples | |||
| """ | |||
| logger.info("Test cache nomap basic 2") | |||
| schema = ds.Schema() | |||
| schema.add_column('image', de_type=mstype.uint8, | |||
| shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) | |||
| schema.add_column('label', de_type=mstype.uint8, shape=[1]) | |||
| # create a cache. arbitrary session_id for now | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| # sampler arg not given directly, however any of these args will auto-generate an appropriate sampler: | |||
| # num_samples, shuffle, num_shards, shard_id | |||
| # In this case, the presence of num_samples chooses a sampler. | |||
| ds1 = ds.RandomDataset(schema=schema, total_rows=20, num_samples=20, num_parallel_workers=4, cache=some_cache) | |||
| ds1 = ds1.repeat(2) | |||
| num_iter = 0 | |||
| for data in ds1.create_dict_iterator(): | |||
| logger.info("printing the label: {}".format(data["label"])) | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 40 | |||
| logger.info("test_cache_nomap_basic2 Ended.\n") | |||
| def test_cache_nomap_basic3(): | |||
| """ | |||
| A TF reader dataset (a non-mappable dataset) with a cache over it just after the leaf | |||
| Repeat | |||
| | | |||
| Map(decode) | |||
| | | |||
| Cache | |||
| | | |||
| TFReader | |||
| """ | |||
| logger.info("Test cache nomap basic 3") | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op) | |||
| ds1 = ds1.repeat(4) | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 12 | |||
| logger.info("test_cache_nomap_basic3 Ended.\n") | |||
| def test_cache_nomap_basic4(): | |||
| """ | |||
| A TF reader dataset (a non-mappable dataset) with a map(decode) and a cache after it | |||
| Since a global shuffle is used for the TF reader, a shuffle op would be injected over it. | |||
| But if there's a cache later, that shuffle becomes redundant and should be removed. | |||
| Repeat | |||
| | | |||
| Cache | |||
| | | |||
| Map(decode) | |||
| | | |||
| TFReader | |||
| """ | |||
| logger.info("Test cache nomap basic 4") | |||
| # This dataset has 3 records in it only | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| # With shuffle not being set, TF defaults to a "global" shuffle when there is no cache | |||
| # in the picture. This causes a shuffle injection over the TF reader. For clarity, this test | |||
| # explicitly gives the global option, even though it's the default in Python. | |||
| # But when caching is added in the tree above the TF reader, global shuffling is done | |||
| # through the sampler over the cache, not by the shuffle op. In that case, tree prepare | |||
| # will remove the shuffle op that got injected by the initial tree creation. | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) | |||
| ds1 = ds1.repeat(4) | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 12 | |||
| logger.info("test_cache_nomap_basic4 Ended.\n") | |||
| def test_cache_nomap_basic5(): | |||
| """ | |||
| A TF reader dataset (a non-mappable dataset) with a cache over it just after the leaf | |||
| Same as test 3, but this one does not give the shuffle arg, so TF defaults to a global | |||
| shuffle, which would inject a shuffle operator. However, since there is a cache, no | |||
| global shuffle is needed, and the shuffle will not be built. The resulting tree is | |||
| identical to basic 3, but we arrive at it through a different code path | |||
| (with no cache, the shuffle IS built) | |||
| Repeat | |||
| | | |||
| Map(decode) | |||
| | | |||
| Cache | |||
| | | |||
| TFReader | |||
| """ | |||
| logger.info("Test cache nomap basic 5") | |||
| # This dataset has 3 records in it only | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], cache=some_cache) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op) | |||
| ds1 = ds1.repeat(4) | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 12 | |||
| logger.info("test_cache_nomap_basic5 Ended.\n") | |||
| def test_cache_nomap_basic6(): | |||
| """ | |||
| A TF reader dataset (a non-mappable dataset) with a cache over it just after the leaf | |||
| In this one, the TF dataset is given a sharding configuration; however, since a cache is | |||
| used, tree prepare should undo the sharding configuration and instead choose a | |||
| distributed sampler with the same shard config. | |||
| Repeat | |||
| | | |||
| Map(decode) | |||
| | | |||
| Cache | |||
| | | |||
| TFReader | |||
| """ | |||
| logger.info("Test cache nomap basic 6") | |||
| # This dataset has 3 records in it only | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| # With only 3 records sharded into 3, we expect only 1 record returned for this shard. | |||
| # However, the sharding will be done by the sampler, not by the TF record leaf node. | |||
| # In this case, it is row-based sharding, not the file-based sharding that would happen | |||
| # if there were no cache. | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_shards=3, shard_id=1, cache=some_cache) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op) | |||
| ds1 = ds1.repeat(4) | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 4 | |||
| logger.info("test_cache_nomap_basic6 Ended.\n") | |||
| def test_cache_nomap_basic7(): | |||
| """ | |||
| A TF reader dataset (a non-mappable dataset) that uses global shuffle, and is cached, | |||
| followed by a map. | |||
| In this one, the TF dataset with global shuffle might want to inject a shuffle op on top | |||
| of the TF reader, but since a cache is given, it will choose not to. | |||
| Repeat | |||
| | | |||
| Map(decode) | |||
| | | |||
| Cache | |||
| | | |||
| TFReader | |||
| """ | |||
| logger.info("Test cache nomap basic 7") | |||
| # This dataset has 3 records in it only | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL, cache=some_cache) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op) | |||
| ds1 = ds1.repeat(4) | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 12 | |||
| logger.info("test_cache_nomap_basic7 Ended.\n") | |||
| def test_cache_nomap_allowed_share1(): | |||
| """ | |||
| It is allowed to share the cache between the following two trees: | |||
| Repeat Shuffle | |||
| | | | |||
| Cache Cache | |||
| | | | |||
| TFReader TFReader | |||
| """ | |||
| logger.info("Test cache nomap allowed share 1") | |||
| ds.config.set_seed(1) | |||
| # This dataset has 3 records in it only | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) | |||
| ds1 = ds1.repeat(4) | |||
| ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) | |||
| ds2 = ds2.shuffle(buffer_size=2) | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert num_iter == 12 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| num_iter = 0 | |||
| for _ in ds2.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert num_iter == 3 | |||
| logger.info("test_cache_nomap_allowed_share1 Ended.\n") | |||
| def test_cache_nomap_allowed_share2(): | |||
| """ | |||
| It is allowed to share the cache between the following two trees (with map decode): | |||
| Repeat Shuffle | |||
| | | | |||
| Cache Cache | |||
| | | | |||
| Map(decode) Map(decode) | |||
| | | | |||
| TFReader TFReader | |||
| """ | |||
| logger.info("Test cache nomap allowed share 2") | |||
| ds.config.set_seed(1) | |||
| # This dataset has 3 records in it only | |||
| some_cache = ds.DatasetCache(session_id=2, size=0, spilling=True) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) | |||
| ds1 = ds1.repeat(4) | |||
| ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| ds2 = ds2.map(input_columns=["image"], operations=decode_op, cache=some_cache) | |||
| ds2 = ds2.shuffle(buffer_size=2) | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 12 | |||
| num_iter = 0 | |||
| for _ in ds2.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert num_iter == 3 | |||
| logger.info("test_cache_nomap_allowed_share2 Ended.\n") | |||
| def test_cache_nomap_allowed_share3(): | |||
| """ | |||
| It is allowed to share the cache between the following two trees (different shard ids): | |||
| Repeat Repeat | |||
| | | | |||
| Cache Cache | |||
| | | | |||
| TFReader(shard_id = 0) TFReader(shard_id = 1) | |||
| """ | |||
| logger.info("Test cache nomap allowed share 3") | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data"] | |||
| ds1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, num_samples=3, shuffle=False, cache=some_cache) | |||
| ds1 = ds1.repeat(4) | |||
| ds2 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=1, num_samples=3, shuffle=False, cache=some_cache) | |||
| ds2 = ds2.repeat(4) | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 12 | |||
| num_iter = 0 | |||
| for _ in ds2.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert num_iter == 12 | |||
| logger.info("test_cache_nomap_allowed_share3 Ended.\n") | |||
| def test_cache_nomap_disallowed_share1(): | |||
| """ | |||
| It is not allowed to share the cache between the following two trees: | |||
| Cache Cache | |||
| | | | |||
| Map(decode) Map(rescale) | |||
| | | | |||
| TFReader TFReader | |||
| """ | |||
| logger.info("Test cache nomap disallowed share1") | |||
| # This dataset has 3 records in it only | |||
| some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| decode_op = c_vision.Decode() | |||
| rescale_op = c_vision.Rescale(1.0 / 255.0, -1.0) | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) | |||
| ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| ds2 = ds2.map(input_columns=["image"], operations=rescale_op, cache=some_cache) | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 3 | |||
| try: | |||
| sum([1 for _ in ds2]) | |||
| except RuntimeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Attempt to re-use a cache for a different tree!" in str(e) | |||
| logger.info("test_cache_nomap_disallowed_share1 Ended.\n") | |||
| if __name__ == '__main__': | |||
| test_cache_nomap_basic1() | |||
| test_cache_nomap_basic2() | |||
| test_cache_nomap_basic3() | |||
| test_cache_nomap_basic4() | |||
| test_cache_nomap_basic5() | |||
| test_cache_nomap_basic6() | |||
| test_cache_nomap_basic7() | |||
| test_cache_nomap_allowed_share1() | |||
| test_cache_nomap_allowed_share2() | |||
| test_cache_nomap_allowed_share3() | |||
| test_cache_nomap_disallowed_share1() | |||
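| Condensing the sharing rule these tests pin down (a reading of the tests themselves, not separately documented semantics): one cache instance may back several trees only when the cached sub-tree is identical; reuse under a different sub-tree is rejected by the server. A sketch reusing this file's DATA_DIR and SCHEMA_DIR constants: | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| cache = ds.DatasetCache(session_id=1, size=0, spilling=True) | |||
| decode_op = c_vision.Decode() | |||
| # Allowed: both pipelines cache the identical sub-tree (TFReader -> map(decode)). | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=cache) | |||
| ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| ds2 = ds2.map(input_columns=["image"], operations=decode_op, cache=cache) | |||
| # Disallowed: reusing the same cache under a different op (e.g. map(rescale)) | |||
| # fails at iteration time with "Attempt to re-use a cache for a different tree!" | |||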
| @@ -16,17 +16,16 @@ import mindspore.common.dtype as mstype | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| # just a basic test with parallel random data op | |||
| def test_randomdataset_basic1(): | |||
| logger.info("Test randomdataset basic") | |||
| logger.info("Test randomdataset basic 1") | |||
| schema = ds.Schema() | |||
| schema.add_column('image', de_type=mstype.uint8, shape=[2]) | |||
| schema.add_column('label', de_type=mstype.uint8, shape=[1]) | |||
| # apply dataset operations | |||
| ds1 = ds.RandomDataset(schema=schema, num_samples=50, num_parallel_workers=4) | |||
| ds1 = ds.RandomDataset(schema=schema, total_rows=50, num_parallel_workers=4) | |||
| ds1 = ds1.repeat(4) | |||
| num_iter = 0 | |||
| @@ -36,8 +35,9 @@ def test_randomdataset_basic1(): | |||
| logger.info("{} label: {}".format(num_iter, data["label"])) | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: ", num_iter) | |||
| logger.info("Number of data in ds1: {}".format(num_iter)) | |||
| assert num_iter == 200 | |||
| logger.info("Test randomdataset basic 1 complete") | |||
| # Another simple test | |||
| @@ -49,10 +49,8 @@ def test_randomdataset_basic2(): | |||
| shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) | |||
| schema.add_column('label', de_type=mstype.uint8, shape=[1]) | |||
| # Make up about 10 samples | |||
| ds1 = ds.RandomDataset(schema=schema, num_samples=10, num_parallel_workers=1) | |||
| # cache size allows for about 4 images since each image just a bit less than 1MB, after that we will have to spill | |||
| # Make up 10 rows | |||
| ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=1) | |||
| ds1 = ds1.repeat(4) | |||
| num_iter = 0 | |||
| @@ -62,11 +60,31 @@ def test_randomdataset_basic2(): | |||
| logger.info("printing the label: {}".format(data["label"])) | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: ", num_iter) | |||
| logger.info("Number of data in ds1: {}".format(num_iter)) | |||
| assert num_iter == 40 | |||
| logger.info("Test randomdataset basic 2 complete") | |||
| # Another simple test | |||
| def test_randomdataset_basic3(): | |||
| logger.info("Test randomdataset basic 3") | |||
| # Make up 10 samples, but here even the schema is randomly created. | |||
| # The columns are named like "c0", "c1", "c2" etc. | |||
| # But we will use a tuple iterator instead of a dict iterator, so the column | |||
| # names are not needed to iterate. | |||
| ds1 = ds.RandomDataset(total_rows=10, num_parallel_workers=1) | |||
| ds1 = ds1.repeat(2) | |||
| num_iter = 0 | |||
| for _ in ds1.create_tuple_iterator(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {}".format(num_iter)) | |||
| assert num_iter == 20 | |||
| logger.info("Test randomdataset basic 3 Complete") | |||
| if __name__ == '__main__': | |||
| test_randomdataset_basic1() | |||
| test_randomdataset_basic2() | |||
| logger.info('test_randomdataset_basic Ended.\n') | |||
| test_randomdataset_basic3() | |||