Merge pull request !460 from xulei/filter_master
@@ -29,6 +29,7 @@
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_sample.h"
#include "mindrecord/include/shard_shuffle.h"
@@ -45,6 +46,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
    {kShuffle, &DEPipeline::ParseShuffleOp},
    {kMindrecord, &DEPipeline::ParseMindRecordOp},
    {kMap, &DEPipeline::ParseMapOp},
    {kFilter, &DEPipeline::ParseFilterOp},
    {kBatch, &DEPipeline::ParseBatchOp},
    {kRepeat, &DEPipeline::ParseRepeatOp},
    {kSkip, &DEPipeline::ParseSkipOp},
@@ -502,6 +504,41 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
  return Status::OK();
}
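// Parses the arguments forwarded from the Python FilterDataset node. For reference,
// the incoming dict is expected to look roughly like (illustrative values; the dict
// is produced by FilterDataset.get_args() on the Python side):
//   {"predicate": <callable>, "input_columns": ["col1"], "num_parallel_workers": 4}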
Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  std::shared_ptr<FilterOp::Builder> builder = std::make_shared<FilterOp::Builder>();
  if (args["predicate"].is_none()) {
    RETURN_STATUS_UNEXPECTED("Error: 'predicate' is not set.\n");
  }
  for (auto arg : args) {
    std::string key = py::str(arg.first);
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
      } else if (key == "predicate") {
        py::handle op = args["predicate"];
        if (!py::isinstance<py::function>(op)) {
          RETURN_STATUS_UNEXPECTED("Error: predicate is not recognised (not pyfunc).");
        }
        py::function predicate_func = op.cast<py::function>();
        (void)builder->SetPredicateFunc(std::move(predicate_func));
      } else if (key == "input_columns") {
        std::vector<std::string> in_col_names = ToStringVector(args["input_columns"]);
        (void)builder->SetInColNames(in_col_names);
      } else {
        RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key);
      }
    }
  }
  std::shared_ptr<FilterOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *ptr = op;
  return Status::OK();
}
Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  if (args["count"].is_none()) {
    std::string err_msg = "Error: count is invalid or not set.";
@@ -671,8 +708,6 @@ Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *
  return Status::OK();
}
DsOpPtr DEPipeline::ParseFilterOp(const py::dict &args) const { return DsOpPtr(); }
Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  // Required arguments
  std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>();
@@ -107,6 +107,8 @@ class DEPipeline {
  Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
  Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
  Status ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
  Status ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
@@ -121,8 +123,6 @@ class DEPipeline {
  Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
  DsOpPtr ParseFilterOp(const py::dict &args) const;
  Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
  Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
@@ -31,6 +31,7 @@
#include "dataset/engine/datasetops/map_op.h"
#include "dataset/engine/datasetops/project_op.h"
#include "dataset/engine/datasetops/rename_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/datasetops/shuffle_op.h"
@@ -240,7 +240,7 @@ void Tensor::PrintItemAt(const std::vector<dsize_t> &index, std::ostream &out) c
  DS_ASSERT(data_);
  switch (type_.value()) {
    CASE_PRINT_HEX(DataType::DE_BOOL, uint8_t);
    CASE_PRINT_HEX(DataType::DE_BOOL, bool);
    CASE_PRINT_HEX(DataType::DE_INT8, int8_t);
@@ -14,5 +14,6 @@ add_library(engine-datasetops OBJECT
    take_op.cc
    shuffle_op.cc
    zip_op.cc
    filter_op.cc
)
@@ -0,0 +1,273 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dataset/engine/datasetops/filter_op.h"
#include <algorithm>
#include <cstring>
#include <iostream>
#include <memory>
#include <vector>
#include "dataset/core/config_manager.h"
#include "dataset/core/constants.h"
#include "dataset/core/global_context.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/kernels/tensor_op.h"
#include "utils/log_adapter.h"
#include "dataset/util/task_manager.h"

namespace mindspore {
namespace dataset {
Status FilterOp::Builder::SanityCheck() {
  std::string err;
  err += builder_op_connector_size_ <= 0 ? "connector size <= 0\n" : "";
  err += builder_num_workers_ <= 0 ? "filter num_parallel_workers <= 0\n" : "";
  return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
}

FilterOp::Builder::Builder() {
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  builder_num_workers_ = cfg->num_parallel_workers();
  builder_op_connector_size_ = cfg->op_connector_size();
}

Status FilterOp::Builder::Build(std::shared_ptr<FilterOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  *ptr = std::make_shared<FilterOp>(std::move(build_in_col_names_), builder_num_workers_, builder_op_connector_size_,
                                    builder_predicate_func_);
  return Status::OK();
}

FilterOp::FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
                   py::function predicate_func)
    : ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {}
Status FilterOp::operator()() {
  // The operator class just starts off threads by calling the tree_ function.
  RETURN_UNEXPECTED_IF_NULL(tree_);
  // Synchronize with TaskManager.
  TaskManager::FindMe()->Post();
  filter_queues_.Init(num_workers_, oc_queue_size_);
  RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1)));
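  // Collector() runs in this (master) thread once the workers are launched; it is the
  // single consumer of filter_queues_ and the single producer for out_connector_.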
  RETURN_IF_NOT_OK(Collector());
  return Status::OK();
}

Status FilterOp::EofReceived(int32_t) { return Status::OK(); }

Status FilterOp::EoeReceived(int32_t) { return Status::OK(); }
// Validating if each of the input_columns exists in the DataBuffer.
Status FilterOp::ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map,
                                   std::vector<std::string> *input_columns) {
  for (const auto &inCol : *input_columns) {
    bool found = col_name_id_map.find(inCol) != col_name_id_map.end();
    if (!found) {
      std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns.";
      RETURN_STATUS_UNEXPECTED(err_msg);
    }
  }
  return Status::OK();
}

// A print method typically used for debugging.
void FilterOp::Print(std::ostream &out, bool show_all) const {
  // Call base class printer first.
  ParallelOp::Print(out, show_all);
  // Then display our own stuff.
  out << "\nFilterOp:";
  out << "\n  Input column names:";
  for (size_t i = 0; i < in_columns_.size(); i++) {
    out << " " << in_columns_[i];
  }
}
Status FilterOp::WorkerEntry(int32_t worker_id) {
  // Handshake with TaskManager that thread creation is successful.
  TaskManager::FindMe()->Post();
  std::unique_ptr<DataBuffer> in_buffer;
  bool worker_stop = false;
  while (worker_stop == false) {
    // Getting a databuffer to work on.
    RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id));
    if (in_buffer->eoe()) {
      filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe));
      continue;
    } else if (in_buffer->eof()) {
      filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof));
      worker_stop = true;
      continue;
    }
    // Thread-local variables to avoid locking. When in_columns_ is empty, workers write
    // the name of the first column into input_columns (thread-local) instead of in_columns_
    // (shared across threads).
    std::vector<std::string> input_columns = in_columns_;
    // Indices of the columns to process.
    std::vector<size_t> to_process_indices;
    RETURN_IF_NOT_OK(WorkerEntryInit(in_buffer.get(), &to_process_indices, &input_columns));
    // If the databuffer was entirely filtered out, it is marked as kFilterEmpty.
    // If the databuffer was partially filtered, it is marked as kFilterPartial.
    // If the databuffer was not filtered at all, it is marked as kFilterFull.
    int32_t num_rows = in_buffer->NumRows();
    std::unique_ptr<TensorQTable> new_tensor_table;
    RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), to_process_indices, &new_tensor_table));
    if (new_tensor_table->empty()) {
      RETURN_IF_NOT_OK(
        filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEmpty)));
    } else if (new_tensor_table->size() == num_rows) {
      in_buffer->set_tensor_table(std::move(new_tensor_table));
      RETURN_IF_NOT_OK(
        filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterFull)));
    } else {  // kFilterPartial
      in_buffer->set_tensor_table(std::move(new_tensor_table));
      RETURN_IF_NOT_OK(
        filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterPartial)));
    }
  }
  return Status::OK();
}
Status FilterOp::WorkerCompute(DataBuffer *in_buffer, const std::vector<size_t> &to_process_indices,
                               std::unique_ptr<TensorQTable> *out) {
  *out = std::make_unique<TensorQTable>();
  int32_t num_rows = in_buffer->NumRows();
  for (int32_t i = 0; i < num_rows; i++) {
    TensorRow to_process;
    TensorRow cur_row;
    RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row));
    (void)std::transform(to_process_indices.begin(), to_process_indices.end(), std::back_inserter(to_process),
                         [&cur_row](const size_t &it) -> std::shared_ptr<Tensor> { return cur_row[it]; });
    bool predicate = true;
    RETURN_IF_NOT_OK(InvokePredicateFunc(to_process, &predicate));
    if (predicate) {
      (*out)->push_back(std::move(cur_row));
    }
  }
  return Status::OK();
}
// If the filtered DataBuffers were written directly to out_connector_, the thread
// fetching data downstream could block on an empty queue. The Collector function
// therefore reorders the DataBuffers across the worker queues.
// For example, with two worker queues:
//   in filter_queues_:
//     queue1: DB(data1 kFilterEmpty) DB(eoe) DB(data4) DB(eof)
//     queue2: DB(data2) DB(data3 kFilterEmpty) DB(eoe)
//   after reordering in out_connector_:
//     queue1: DB(data2) DB(data4) DB(eof)
//     queue2: DB(eoe) DB(eoe)
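// Note: the worker queues are polled round-robin via task_id_cnt, while out_id_cnt
// assigns contiguous output queue ids, so kFilterEmpty buffers are dropped without
// leaving gaps in the output sequence.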
Status FilterOp::Collector() {
  bool collector_stop = false;
  uint64_t task_id_cnt = 0;
  uint64_t out_id_cnt = 0;
  std::pair<std::unique_ptr<DataBuffer>, filterCtrl> in_pair;
  while (collector_stop == false) {
    uint32_t w_id = task_id_cnt % num_workers_;
    RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair));
    if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial ||
        in_pair.second == filterCtrl::kFilterEoe) {
      uint32_t out_task_id = out_id_cnt % num_workers_;
      RETURN_IF_NOT_OK(out_connector_->Add(static_cast<int>(out_task_id), std::move(in_pair.first)));
      out_id_cnt++;
      task_id_cnt++;
    } else if (in_pair.second == filterCtrl::kFilterEof) {
      uint32_t out_task_id = out_id_cnt % num_workers_;
      RETURN_IF_NOT_OK(out_connector_->Add(static_cast<int>(out_task_id), std::move(in_pair.first)));
      collector_stop = true;
    } else {  // kFilterEmpty
      task_id_cnt++;
    }
  }
  return Status::OK();
}
// Initialize some internal data structures used by WorkerEntry().
Status FilterOp::WorkerEntryInit(const DataBuffer *in_buf, std::vector<size_t> *to_process_indices,
                                 std::vector<std::string> *input_columns) {
  int32_t num_rows = in_buf->NumRows();
  int32_t num_cols = in_buf->NumCols();
  if (num_rows == 0 || num_cols == 0) {
    RETURN_STATUS_UNEXPECTED("FilterOp is getting an empty DataBuffer.");
  }
  std::unordered_map<std::string, int32_t> col_name_id_map = in_buf->column_name_map();
  // Check if there is an invalid column name in input_columns.
  RETURN_IF_NOT_OK(ValidateInColumns(col_name_id_map, input_columns));
  if (input_columns->empty()) {
    MS_LOG(INFO) << "Input columns in filter operator is empty, will apply to all columns in the current table.";
    // Sort the input columns by column index.
    std::vector<std::pair<std::string, int32_t>> sort_vec(col_name_id_map.begin(), col_name_id_map.end());
    std::sort(sort_vec.begin(), sort_vec.end(),
              [](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
                return a.second < b.second;
              });
    (void)std::transform(sort_vec.begin(), sort_vec.end(), std::back_inserter(*input_columns),
                         [](const auto &it) -> std::string { return it.first; });
  }
  // Initialize to_process_indices.
  (void)std::transform(input_columns->begin(), input_columns->end(), std::back_inserter(*to_process_indices),
                       [&col_name_id_map](const auto &it) -> size_t { return col_name_id_map[it]; });
  return Status::OK();
}
Status FilterOp::CheckInput(const TensorRow &input) const {
  for (auto &item : input) {
    if (item == nullptr) {
      RETURN_STATUS_UNEXPECTED("input is null.");
    }
  }
  return Status::OK();
}
Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate) {
  RETURN_IF_NOT_OK(CheckInput(input));
  // Acquire Python GIL.
  py::gil_scoped_acquire gil_acquire;
  if (Py_IsInitialized() == 0) {
    return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
  }
  try {
    // Transform the input tensor vector into a vector of numpy arrays.
    py::tuple input_args(input.size());
    for (size_t i = 0; i < input.size(); i++) {
      py::array new_data;
      RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data));
      input_args[i] = new_data;
    }
    // Invoke the python function.
    py::object ret_py_obj = predicate_func_(*input_args);
    *out_predicate = ret_py_obj.cast<py::bool_>();
  } catch (const py::error_already_set &e) {
    std::stringstream ss;
    ss << e.what() << std::endl;
    ss << "The return value of the python predicate function is not bool, or cannot be converted to bool.";
    return Status(StatusCode::kPyFuncException, ss.str());
  }
  return Status(StatusCode::kOK, "FilterOp predicate func call succeeded");
}
}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,180 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_FILTER_OP_H_
#define DATASET_ENGINE_DATASETOPS_FILTER_OP_H_
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/util/queue.h"

namespace mindspore {
namespace dataset {
class FilterOp : public ParallelOp {
 public:
  // The nested builder class inside of the FilterOp is used to help manage all of
  // the arguments for constructing it. Use the builder by setting each argument
  // with the provided set methods, and then finally call the build method to execute
  // the actual construction.
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args.
    // @return This is a constructor.
    Builder();

    // Default destructor.
    ~Builder() = default;

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetPredicateFunc(py::function func) {
      builder_predicate_func_ = std::move(func);
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetInColNames(const std::vector<std::string> &in_col_names) {
      build_in_col_names_ = in_col_names;
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetNumWorkers(int32_t num_workers) {
      builder_num_workers_ = num_workers;
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetOpConnectorSize(int32_t connector_size) {
      builder_op_connector_size_ = connector_size;
      return *this;
    }

    // The builder "build" method creates the final object.
    // @param ptr The shared_ptr to the new FilterOp object.
    // @return Status.
    Status Build(std::shared_ptr<FilterOp> *ptr);

   private:
    // Sanity check for builder class args.
    // @return Status - The error code return.
    Status SanityCheck();
    std::vector<std::string> build_in_col_names_;
    py::function builder_predicate_func_;
    int32_t builder_num_workers_;
    int32_t builder_op_connector_size_;
  };
  enum filterCtrl : int8_t { kFilterEmpty = 0, kFilterPartial = 1, kFilterFull = 2, kFilterEoe = 3, kFilterEof = 4 };
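  // kFilterEmpty: every row in the buffer was dropped; kFilterPartial: some rows remain;
  // kFilterFull: no row was dropped; kFilterEoe/kFilterEof: flow-control buffers that are
  // forwarded downstream unchanged.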
  // Constructor of FilterOp.
  // @note The builder class should be used to call it.
  // @param in_col_names A list of input column names; when it is empty the predicate will be
  //     applied to all columns in the dataset.
  // @param num_workers The number of worker threads.
  // @param op_connector_size The size of each queue in the connector.
  // @param predicate_func Python callable which returns a boolean value.
  FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
           py::function predicate_func);

  // Class functor operator () override.
  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
  // provide the master loop that drives the logic for performing the work.
  // @return Status The error code return.
  Status operator()() override;

  // @param int32_t workerId.
  // @return Status - The error code return.
  Status EofReceived(int32_t) override;

  // @param int32_t workerId.
  // @return Status - The error code return.
  Status EoeReceived(int32_t) override;

  // A print method typically used for debugging.
  // @param out The output stream to write output to.
  // @param show_all A bool to control if you want to show all info or just a summary.
  void Print(std::ostream &out, bool show_all) const override;
 private:
  // Predicate function: a python callable which returns a boolean value.
  py::function predicate_func_;

  // Variable to store the column names that will be fed to the predicate function.
  std::vector<std::string> in_columns_;

  // Internal queues for filter.
  QueueList<std::pair<std::unique_ptr<DataBuffer>, filterCtrl>> filter_queues_;

  // Private function for worker/thread to loop continuously. It comprises the main
  // logic of FilterOp: getting the data from the previous op, validating user-specified
  // column names, applying the predicate to each row, and dropping the row when the
  // predicate result is false.
  // @param worker_id The id assigned to this thread/worker upon creation.
  // @return Status The error code return.
  Status WorkerEntry(int32_t worker_id) override;  // In: workerId assigned by tree_

  // Filter the data by the predicate function.
  // @param in_buffer Input data buffer.
  // @param to_process_indices Indices of columns to be processed.
  // @param out Data buffer holding the rows that passed the predicate.
  // @return Status The error code return.
  Status WorkerCompute(DataBuffer *in_buffer, const std::vector<size_t> &to_process_indices,
                       std::unique_ptr<TensorQTable> *out);

  // Collect and reorder DataBuffers from the worker queues.
  // @return Status The error code return.
  Status Collector();

  // @param input Tensor vector.
  // @return Status - The error code return.
  Status CheckInput(const TensorRow &input) const;

  // Invoke the python predicate function.
  // @param input Tensor vector.
  // @param[out] out_predicate The result of the predicate.
  // @return Status - The error code return.
  Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate);

  // Private function for validating if each of the user-specified input column names
  // exists in the DataBuffer.
  // @param col_name_id_map The column name to index mapping obtained from the DataBuffer.
  // @param input_columns The vector of input column names used in the current thread.
  // @return Status The error code return.
  Status ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map,
                           std::vector<std::string> *input_columns);

  // Private function that initializes some internal data structures used by WorkerEntry().
  // @param in_buf A raw pointer to the DataBuffer. A raw pointer is fine because this function does not manage memory
  //     and is not shared with other threads.
  // @param[out] to_process_indices Indices of columns that will be fed to the predicate.
  // @param input_columns The vector of input column names used in the current thread.
  Status WorkerEntryInit(const DataBuffer *in_buf, std::vector<size_t> *to_process_indices,
                         std::vector<std::string> *input_columns);
};
}  // namespace dataset
}  // namespace mindspore
#endif  // DATASET_ENGINE_DATASETOPS_FILTER_OP_H_
@@ -35,7 +35,7 @@ from mindspore._c_expression import typing
from mindspore import log as logger
from . import samplers
from .iterators import DictIterator, TupleIterator
from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \
from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, check_rename, \
    check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
    check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
    check_zip_dataset, check_add_column, check_textfiledataset
@@ -385,6 +385,32 @@ class Dataset:
        """
        return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers)
    @check_filter
    def filter(self, predicate, input_columns=None, num_parallel_workers=1):
        """
        Filter dataset by predicate.

        Note:
            If input_columns is not provided or empty, all columns will be used.

        Args:
            predicate: python callable which returns a boolean value.
            input_columns (list[str], optional): List of names of the input columns
                (default=None, the predicate will be applied to all columns in the dataset).
            num_parallel_workers (int, optional): Number of workers to process the Dataset
                in parallel (default=1).

        Returns:
            FilterDataset, dataset filtered by the predicate.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # generator data(0 ~ 63)
            >>> # filter out the data that is greater than or equal to 11
            >>> dataset_f = dataset.filter(predicate=lambda data: data < 11, input_columns=["data"])
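            >>> # when input_columns is omitted, the predicate is applied to all columns
            >>> dataset_all = dataset.filter(predicate=lambda data: data % 2 == 0)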
| """ | |||
| return FilterDataset(self, predicate, input_columns, num_parallel_workers) | |||
    @check_repeat
    def repeat(self, count=None):
        """
@@ -1105,6 +1131,44 @@ class MapDataset(DatasetOp):
        return self.input[0].get_dataset_size()
class FilterDataset(DatasetOp):
    """
    The result of applying the filter predicate to the input Dataset.

    Args:
        input_dataset: Input Dataset to be filtered.
        predicate: python callable which returns a boolean value.
        input_columns (list[str], optional): List of names of the input columns
            (default=None, the predicate will be applied to all columns in the dataset).
        num_parallel_workers (int, optional): Number of workers to process the Dataset
            in parallel (default=None).
    """

    def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None):
        super().__init__(num_parallel_workers)
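        # Coerce the user predicate's return value to a plain bool before it crosses
        # into the C++ layer (InvokePredicateFunc casts the result to py::bool_).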
        self.predicate = lambda *args: bool(predicate(*args))
        self.input.append(input_dataset)
        input_dataset.output.append(self)
        if input_columns is not None and not isinstance(input_columns, list):
            input_columns = [input_columns]
        self.input_columns = input_columns

    def get_args(self):
        args = super().get_args()
        args["predicate"] = self.predicate
        args["input_columns"] = self.input_columns
        return args
    def get_dataset_size(self):
        """
        Get the number of batches in an epoch.

        The size cannot be determined before the pipeline has run, so 0 is returned.

        Return:
            0
        """
        return 0

class RepeatDataset(DatasetOp):
    """
    The result of applying Repeat operator to the input Dataset.
@@ -129,6 +129,8 @@ class Iterator:
            op_type = OpName.ZIP
        elif isinstance(dataset, de.MapDataset):
            op_type = OpName.MAP
        elif isinstance(dataset, de.FilterDataset):
            op_type = OpName.FILTER
        elif isinstance(dataset, de.RepeatDataset):
            op_type = OpName.REPEAT
        elif isinstance(dataset, de.SkipDataset):
@@ -693,6 +693,26 @@ def check_map(method):
    return new_method
def check_filter(method):
    """Check the input arguments of filter."""

    @wraps(method)
    def new_method(*args, **kwargs):
        param_dict = make_param_dict(method, args, kwargs)
        predicate = param_dict.get("predicate")
        if not callable(predicate):
            raise ValueError("Predicate should be a python function or a callable python object.")

        nreq_param_int = ['num_parallel_workers']
        check_param_type(nreq_param_int, param_dict, int)
        param_name = "input_columns"
        param = param_dict.get(param_name)
        if param is not None:
            check_columns(param, param_name)
        return method(*args, **kwargs)

    return new_method
def check_repeat(method):
    """check the input arguments of repeat."""

    @wraps(method)
@@ -66,6 +66,8 @@ SET(DE_UT_SRCS
    celeba_op_test.cc
    take_op_test.cc
    text_file_op_test.cc
    filter_op_test.cc
    )
add_executable(de_ut_tests ${DE_UT_SRCS})
@@ -0,0 +1,53 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dataset/util/circular_pool.h"
#include "dataset/core/client.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"

using namespace mindspore::dataset;
namespace de = mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;

class MindDataTestfilter_op : public UT::DatasetOpTesting {};
std::shared_ptr<de::FilterOp> Filter() {
  Status rc;
  std::shared_ptr<de::FilterOp> op;
  rc = de::FilterOp::Builder().Build(&op);
  EXPECT_TRUE(rc.IsOk());
  return op;
}

TEST_F(MindDataTestfilter_op, Testfilter_opFunctions) {
  MS_LOG(INFO) << "Doing MindDataTest filter_op.";
  auto my_tree = std::make_shared<ExecutionTree>();
  std::shared_ptr<DatasetOp> parent_op = Filter();
  std::shared_ptr<DatasetOp> leaf_op = Filter();
  my_tree->AssociateNode(parent_op);
  my_tree->AssociateNode(leaf_op);
  ASSERT_NE(parent_op, nullptr);
  ASSERT_NE(leaf_op, nullptr);
}
@@ -158,6 +158,16 @@ TEST_F(MindDataTestTensorDE, InsertTensor) {
  ASSERT_EQ(*t == *t6, true);
}

// Regression test: Tensor::ToString used to fail for a Tensor which stores bool values.
TEST_F(MindDataTestTensorDE, BoolTensor) {
  std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2}),
                                                       DataType(DataType::DE_BOOL));
  t->SetItemAt<bool>({0}, true);
  t->SetItemAt<bool>({1}, true);
  std::string out = t->ToString();
  ASSERT_TRUE(out.find("Template type and Tensor type are not compatible") == std::string::npos);
}

TEST_F(MindDataTestTensorDE, GetItemAt) {
  std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 2}), DataType(DataType::DE_UINT8));
  t->Fill<uint8_t>(254);
@@ -0,0 +1,3 @@
{
  "rowsPerBuffer": 10
}
@@ -0,0 +1,504 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as cde
import mindspore.dataset.transforms.c_transforms as C
import mindspore.common.dtype as mstype
from mindspore import log as logger

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
# test for predicate
def test_diff_predicate_func():
    def test_filter(predicate_func):
        transforms = [
            cde.Decode(),
            cde.Resize([64, 64])
        ]
        type_cast_op = C.TypeCast(mstype.int32)
        dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image", "label"], shuffle=False)
        dataset = dataset.map(input_columns=["image"], operations=transforms, num_parallel_workers=1)
        dataset = dataset.filter(input_columns=["image", "label"], predicate=predicate_func, num_parallel_workers=4)

        num_iter = 0
        label_list = []
        for data in dataset.create_dict_iterator():
            num_iter += 1
            ori_img = data["image"]
            label = data["label"]
            label_list.append(label)
        assert num_iter == 1
        assert label_list[0] == 3

    test_filter(lambda image, label: label == 3)
    test_filter(lambda image, label: label[0] == 3)
    test_filter(lambda image, label: label == [3])
    test_filter(lambda image, label: label == np.array([3]))
    test_filter(lambda image, label: label == np.array(3))
def filter_func_ge(data):
    if data > 10:
        return False
    return True

def generator_1d():
    for i in range(64):
        yield (np.array(i),)

# test with GeneratorDataset
def test_filter_by_generator_with_no():
    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    dataset_f = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
    num_iter = 0
    expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    for item in dataset_f.create_dict_iterator():
        assert item["data"] == expected_rs[num_iter]
        num_iter += 1
# test with repeatOp before
def test_filter_by_generator_with_repeat():
    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    dataset_r = dataset.repeat(4)
    dataset_f = dataset_r.filter(predicate=filter_func_ge, num_parallel_workers=4)
    num_iter = 0
    ret_data = []
    expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    for item in dataset_f.create_dict_iterator():
        num_iter += 1
        ret_data.append(item["data"])
    assert num_iter == 44
    for i in range(4):
        for ii in range(len(expected_rs)):
            index = i * len(expected_rs) + ii
            assert ret_data[index] == expected_rs[ii]

# test with repeatOp after
def test_filter_by_generator_with_repeat_after():
    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    dataset_f = dataset.filter(predicate=filter_func_ge, num_parallel_workers=4)
    dataset_r = dataset_f.repeat(4)
    num_iter = 0
    ret_data = []
    expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    for item in dataset_r.create_dict_iterator():
        num_iter += 1
        ret_data.append(item["data"])
    assert num_iter == 44
    for i in range(4):
        for ii in range(len(expected_rs)):
            index = i * len(expected_rs) + ii
            assert ret_data[index] == expected_rs[ii]
def filter_func_batch(data):
    if data[0] > 8:
        return False
    return True

def filter_func_batch_after(data):
    if data > 20:
        return False
    return True

# test with batchOp before
def test_filter_by_generator_with_batch():
    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    dataset_b = dataset.batch(4)
    dataset_f = dataset_b.filter(predicate=filter_func_batch, num_parallel_workers=4)
    num_iter = 0
    ret_data = []
    for item in dataset_f.create_dict_iterator():
        num_iter += 1
        ret_data.append(item["data"])
    assert num_iter == 3
    assert ret_data[0][0] == 0
    assert ret_data[1][0] == 4
    assert ret_data[2][0] == 8

# test with batchOp after
def test_filter_by_generator_with_batch_after():
    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    dataset_f = dataset.filter(predicate=filter_func_batch_after, num_parallel_workers=4)
    dataset_b = dataset_f.batch(4)
    num_iter = 0
    ret_data = []
    for item in dataset_b.create_dict_iterator():
        num_iter += 1
        ret_data.append(item["data"])
    assert num_iter == 6
    assert ret_data[0][0] == 0
    assert ret_data[1][0] == 4
    assert ret_data[5][0] == 20
def filter_func_shuffle(data):
    if data > 20:
        return False
    return True

# test with shuffleOp before
def test_filter_by_generator_with_shuffle():
    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    dataset_s = dataset.shuffle(4)
    dataset_f = dataset_s.filter(predicate=filter_func_shuffle, num_parallel_workers=4)
    num_iter = 0
    for item in dataset_f.create_dict_iterator():
        num_iter += 1
    assert num_iter == 21

def filter_func_shuffle_after(data):
    if data > 20:
        return False
    return True

# test with shuffleOp after
def test_filter_by_generator_with_shuffle_after():
    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    dataset_f = dataset.filter(predicate=filter_func_shuffle_after, num_parallel_workers=4)
    dataset_s = dataset_f.shuffle(4)
    num_iter = 0
    for item in dataset_s.create_dict_iterator():
        num_iter += 1
    assert num_iter == 21
def generator_1d_zip1():
    for i in range(64):
        yield (np.array(i),)

def generator_1d_zip2():
    for i in range(64):
        yield (np.array(i + 100),)

def filter_func_zip(data1, data2):
    if data1 > 20:
        return False
    return True

def filter_func_zip_after(data1):
    if data1 > 20:
        return False
    return True

# test with zipOp before
def test_filter_by_generator_with_zip():
    dataset1 = ds.GeneratorDataset(generator_1d_zip1, ["data1"])
    dataset2 = ds.GeneratorDataset(generator_1d_zip2, ["data2"])
    dataz = ds.zip((dataset1, dataset2))
    dataset_f = dataz.filter(predicate=filter_func_zip, num_parallel_workers=1)
    num_iter = 0
    ret_data = []
    for item in dataset_f.create_dict_iterator():
        num_iter += 1
        ret_data.append({"data1": item["data1"], "data2": item["data2"]})
    assert num_iter == 21
    assert ret_data[0]["data1"] == 0
    assert ret_data[0]["data2"] == 100
    assert ret_data[5]["data1"] == 5
    assert ret_data[5]["data2"] == 105

# test with zipOp after
def test_filter_by_generator_with_zip_after():
    dataset1 = ds.GeneratorDataset(generator_1d_zip1, ["data1"])
    dataset2 = ds.GeneratorDataset(generator_1d_zip1, ["data2"])
    dt1 = dataset1.filter(predicate=filter_func_zip_after, num_parallel_workers=4)
    dt2 = dataset2.filter(predicate=filter_func_zip_after, num_parallel_workers=4)
    dataz = ds.zip((dt1, dt2))
    num_iter = 0
    ret_data = []
    for item in dataz.create_dict_iterator():
        num_iter += 1
        ret_data.append({"data1": item["data1"], "data2": item["data2"]})
    assert num_iter == 21
    assert ret_data[0]["data1"] == 0
    assert ret_data[0]["data2"] == 0
    assert ret_data[5]["data1"] == 5
    assert ret_data[5]["data2"] == 5
def filter_func_map(col1, col2):
    if col1[0] > 8:
        return True
    return False

def filter_func_map_part(col1):
    if col1 < 3:
        return True
    else:
        return False

def filter_func_map_all(col1, col2):
    return True

def generator_mc(maxid=20):
    for i in range(maxid):
        yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))

def func_map(data_col1, data_col2):
    return (data_col1, data_col2)

def func_map_part(data_col1):
    return (data_col1)

# test with map
def test_filter_by_generator_with_map_all_col():
    dataset = ds.GeneratorDataset(generator_mc(12), ["col1", "col2"])
    dataset_map = dataset.map(input_columns=["col1"], output_columns=["col1"], operations=func_map_part)
    # dataset_map = dataset.map(operations=func_map_part)
    dataset_f = dataset_map.filter(input_columns=["col1"], predicate=filter_func_map_part, num_parallel_workers=1)
    num_iter = 0
    ret_data = []
    for item in dataset_f.create_dict_iterator():
        num_iter += 1
        ret_data.append(item["col1"])
    assert num_iter == 3
    assert ret_data[0] == 0
    assert ret_data[1] == 1
# test with map
def test_filter_by_generator_with_map_part_col():
    dataset = ds.GeneratorDataset(generator_mc(12), ["col1", "col2"])
    dataset_map = dataset.map(input_columns=["col1"], output_columns=["out1"], operations=func_map_part)
    dataset_f = dataset_map.filter(input_columns=["out1", "col2"], predicate=filter_func_map, num_parallel_workers=4)
    num_iter = 0
    ret_data = []
    for item in dataset_f.create_dict_iterator():
        num_iter += 1
        print(item)
        ret_data.append(item["out1"])
    assert num_iter == 3
    assert ret_data[0] == 9
    assert ret_data[2] == 11
def filter_func_rename(data):
    if data > 8:
        return True
    return False

# test with rename before
def test_filter_by_generator_with_rename():
    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    dataset_b = dataset.rename(input_columns=["data"], output_columns=["col1"])
    dataset_f = dataset_b.filter(predicate=filter_func_rename, num_parallel_workers=4)
    num_iter = 0
    ret_data = []
    for item in dataset_f.create_dict_iterator():
        num_iter += 1
        ret_data.append(item["col1"])
    assert num_iter == 55
    assert ret_data[0] == 9
    assert ret_data[54] == 63
# test input_columns
def filter_func_input_column1(col1, col2):
    if col1[0] < 8:
        return True
    return False

def filter_func_input_column2(col1):
    if col1[0] < 8:
        return True
    return False

def filter_func_input_column3(col1):
    return True

# test with input_columns
def test_filter_by_generator_with_input_column():
    dataset = ds.GeneratorDataset(generator_mc(64), ["col1", "col2"])
    dataset_map = dataset.map(input_columns=["col1"], output_columns=["out1"], operations=func_map_part)
    dataset_f1 = dataset_map.filter(input_columns=["out1", "col2"], predicate=filter_func_input_column1, num_parallel_workers=4)
    dataset_f2 = dataset_f1.filter(input_columns=["out1"], predicate=filter_func_input_column2, num_parallel_workers=4)
    dataset_f3 = dataset_f2.filter(input_columns=["col2"], predicate=filter_func_input_column3, num_parallel_workers=4)
    dataset_f4 = dataset_f3.filter(predicate=filter_func_input_column1, num_parallel_workers=4)
    num_iter = 0
    ret_data = []
    for item in dataset_f4.create_dict_iterator():
        num_iter += 1
        ret_data.append(item["out1"])
    assert num_iter == 8
    assert ret_data[0] == 0
    assert ret_data[7] == 7
# test kFilterPartial
def generator_mc_p0(maxid=20):
    for i in range(maxid):
        yield (np.array([i]), np.array([i + 100]))

def generator_mc_p1(maxid=20):
    for i in range(maxid):
        yield (np.array([i + 200]), np.array([i + 300]))

def filter_func_Partial_0(col1, col2, col3, col4):
    filter_data = [0, 1, 2, 3, 4, 11]
    if col1[0] in filter_data:
        return False
    return True

# test with row_data_buffer > 1
def test_filter_by_generator_Partial0():
    ds.config.load('../data/dataset/declient_filter.cfg')
    dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"])
    dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"])
    dataset_zip = ds.zip((dataset1, dataset2))
    dataset_f1 = dataset_zip.filter(predicate=filter_func_Partial_0, num_parallel_workers=2)
    ret = []
    for item in dataset_f1.create_dict_iterator():
        ret.append(item["col1"])
    assert ret[0] == 5
    assert ret[6] == 12

# test with row_data_buffer > 1
def test_filter_by_generator_Partial1():
    ds.config.load('../data/dataset/declient_filter.cfg')
    dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"])
    dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"])
    dataset_zip = ds.zip((dataset1, dataset2))
    dataset_f1 = dataset_zip.filter(predicate=filter_func_Partial_0, num_parallel_workers=2)
    dataset_map = dataset_f1.map(input_columns=["col1"], output_columns=["out1"], operations=lambda x1: x1 + 400)
    ret = []
    for item in dataset_map.create_dict_iterator():
        ret.append(item["out1"])
    assert ret[0] == 405
    assert ret[6] == 412

# test with row_data_buffer > 1
def test_filter_by_generator_Partial2():
    ds.config.load('../data/dataset/declient_filter.cfg')
    dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"])
    dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"])
    dataset1f = dataset1.filter(input_columns=["col1"], predicate=lambda x: x not in [3, 7, 9], num_parallel_workers=2)
    dataset2f = dataset2.filter(input_columns=["col3"], predicate=lambda x: x not in [203, 207, 209], num_parallel_workers=2)
    dataset_zip = ds.zip((dataset1f, dataset2f))
    dataset_map = dataset_zip.map(input_columns=["col1", "col3"], output_columns=["out1", "out3"], operations=lambda x1, x3: (x1 + 400, x3 + 500))
    ret1 = []
    ret3 = []
    for item in dataset_map.create_dict_iterator():
        ret1.append(item["out1"])
        ret3.append(item["out3"])
    assert ret1[0] == 400
    assert ret1[6] == 408
    assert ret3[0] == 700
    assert ret3[6] == 708
def filter_func_Partial(col1, col2):
    if col1[0] % 3 == 0:
        return True
    return False

def generator_big(maxid=20):
    for i in range(maxid):
        yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))

# test with row_data_buffer > 1
def test_filter_by_generator_Partial():
    ds.config.load('../data/dataset/declient_filter.cfg')
    dataset = ds.GeneratorDataset(source=generator_mc(99), column_names=["col1", "col2"])
    dataset_s = dataset.shuffle(4)
    dataset_f1 = dataset_s.filter(input_columns=["col1", "col2"], predicate=filter_func_Partial, num_parallel_workers=1)
    for item in dataset_f1.create_dict_iterator():
        assert item["col1"] % 3 == 0
def filter_func_cifar(col1, col2):
    if col2 % 3 == 0:
        return True
    return False

# test with cifar10
def test_filter_case_dataset_cifar10():
    DATA_DIR_10 = "../data/dataset/testCifar10Data"
    ds.config.load('../data/dataset/declient_filter.cfg')
    dataset_c = ds.Cifar10Dataset(dataset_dir=DATA_DIR_10, num_samples=100000, shuffle=False)
    dataset_f1 = dataset_c.filter(input_columns=["image", "label"], predicate=filter_func_cifar, num_parallel_workers=1)
    num_iter = 0
    for item in dataset_f1.create_dict_iterator():
        # in this example, each dictionary has keys "image" and "label"
        assert item["label"] % 3 == 0
# column id sort
def generator_sort1(maxid=20):
    for i in range(maxid):
        yield (np.array([i]), np.array([i + 100]), np.array([i + 200]))

def generator_sort2(maxid=20):
    for i in range(maxid):
        yield (np.array([i + 300]), np.array([i + 400]), np.array([i + 500]))

def filter_func_part_sort(col1, col2, col3, col4, col5, col6):
    return True

def filter_func_map_sort(col1, col2, col3):
    return (col1, col2, col3)

def test_filter_by_generator_with_map_all_sort():
    dataset1 = ds.GeneratorDataset(generator_sort1(10), ["col1", "col2", "col3"])
    dataset2 = ds.GeneratorDataset(generator_sort2(10), ["col4", "col5", "col6"])
    dataz = ds.zip((dataset1, dataset2))
    dataset_f = dataz.filter(predicate=filter_func_part_sort, num_parallel_workers=1)
    num_iter = 0
    ret_data = []
    for item in dataset_f.create_dict_iterator():
        num_iter += 1
        ret_data.append(item)
    assert num_iter == 10
    assert ret_data[0]["col1"] == 0
    assert ret_data[9]["col6"] == 509
if __name__ == '__main__':
    test_diff_predicate_func()
    test_filter_case_dataset_cifar10()
    test_filter_by_generator_Partial0()
    test_filter_by_generator_Partial1()
    test_filter_by_generator_Partial2()
    test_filter_by_generator_with_batch()
    test_filter_by_generator_with_batch_after()
    test_filter_by_generator_with_input_column()
    test_filter_by_generator_with_map_all_col()
    test_filter_by_generator_with_map_all_sort()
    test_filter_by_generator_with_map_part_col()
    test_filter_by_generator_with_no()
    test_filter_by_generator_with_rename()
    test_filter_by_generator_with_repeat()
    test_filter_by_generator_with_repeat_after()
    test_filter_by_generator_with_shuffle()
    test_filter_by_generator_with_shuffle_after()
    test_filter_by_generator_with_zip()
    test_filter_by_generator_with_zip_after()
    test_filter_by_generator_Partial()
@@ -25,8 +25,8 @@ COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
def check(project_columns):
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS)
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=project_columns)
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS, shuffle=False)
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=project_columns, shuffle=False)
    for data_actual, data_expected in zip(data1.create_tuple_iterator(project_columns), data2.create_tuple_iterator()):
        assert len(data_actual) == len(data_expected)