Merge pull request !1157 from ms_yan/concat_dataset (tags/v0.3.0-alpha)
@@ -53,6 +53,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
   {kRepeat, &DEPipeline::ParseRepeatOp},
   {kSkip, &DEPipeline::ParseSkipOp},
   {kZip, &DEPipeline::ParseZipOp},
+  {kConcat, &DEPipeline::ParseConcatOp},
   {kRename, &DEPipeline::ParseRenameOp},
   {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp},
   {kGenerator, &DEPipeline::ParseGeneratorOp},
@@ -757,6 +758,14 @@ Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *
   return Status::OK();
 }
 
+Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
+  std::shared_ptr<ConcatOp::Builder> builder = std::make_shared<ConcatOp::Builder>();
+  std::shared_ptr<ConcatOp> op;
+  RETURN_IF_NOT_OK(builder->Build(&op));
+  *ptr = op;
+  return Status::OK();
+}
+
 Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
   // Required arguments
   std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>();
@@ -46,6 +46,7 @@ enum OpName {
   kSkip,
   kTake,
   kZip,
+  kConcat,
   kMap,
   kFilter,
   kDeviceQueue,
@@ -127,6 +128,8 @@ class DEPipeline {
   Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
+  Status ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+
   Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
   Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
@@ -476,6 +476,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
     .value("SKIP", OpName::kSkip)
     .value("TAKE", OpName::kTake)
     .value("ZIP", OpName::kZip)
+    .value("CONCAT", OpName::kConcat)
     .value("MAP", OpName::kMap)
     .value("FILTER", OpName::kFilter)
     .value("DEVICEQUEUE", OpName::kDeviceQueue)
@@ -42,6 +42,7 @@
 #include "dataset/engine/datasetops/source/tf_reader_op.h"
 #include "dataset/engine/datasetops/take_op.h"
 #include "dataset/engine/datasetops/zip_op.h"
+#include "dataset/engine/datasetops/concat_op.h"
 #include "dataset/engine/execution_tree.h"
 #include "dataset/util/status.h"
@@ -17,6 +17,7 @@ add_library(engine-datasetops OBJECT
   take_op.cc
   shuffle_op.cc
   zip_op.cc
+  concat_op.cc
   filter_op.cc
 )
@@ -0,0 +1,145 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iomanip>
#include <utility>
#include "common/utils.h"
#include "dataset/core/config_manager.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/datasetops/concat_op.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"

namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
ConcatOp::Builder::Builder() {
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  builder_op_connector_size_ = cfg->op_connector_size();
}

// The builder "build" method creates the final object.
Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
  *ptr = std::make_shared<ConcatOp>(builder_op_connector_size_);
  return Status::OK();
}

// Constructor of the ConcatOp.
ConcatOp::ConcatOp(int32_t op_connector_size) : PipelineOp(op_connector_size), children_num_(0) {}
// A function that prints info about the Operator
void ConcatOp::Print(std::ostream &out, bool show_all) const {
  // Always show the id and name as the first line, regardless of whether this is a summary or detailed print
  out << "(" << std::setw(2) << operator_id_ << ") <ConcatOp>:";
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    PipelineOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    PipelineOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nDatasets: " << children_num_ << "\n\n";
  }
}

// Main entry point for Concat
Status ConcatOp::operator()() {
  // children_num_ must be initialized here, once all child operators have been attached to the tree
  children_num_ = static_cast<int32_t>(child_.size());
  TaskManager::FindMe()->Post();
  std::unique_ptr<DataBuffer> buf;
  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));

  // Obtain column_name_id_map from child_[0]
  column_name_id_map_ = child_[0]->column_name_id_map();
  if (column_name_id_map_.empty()) {
    RETURN_STATUS_UNEXPECTED("Child column name map cannot be empty!");
  }
  int eof_count = 0;
  while (eof_count != children_num_) {
    for (int i = 0; i < children_num_; i++) {
      // 1. Throw away the eof/eoe buffer when we meet it
      if (buf->eof() || buf->eoe()) {
        RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
      }
      // 2. Verify the column name, column data type and rank of the column data
      RETURN_IF_NOT_OK(Verify(i, buf));
      // 3. Put the data into output_connector
      while (!buf->eoe() && !buf->eof()) {
        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
        RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
      }
      // 4. Throw away the eoe buffer when we meet it
      if (buf->eoe() && (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat))) {
        RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
      }
      // 5. Add one eoe buffer after buffers have been taken from all children
      if (i == (children_num_ - 1)) {
        auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
      }
      if (buf->eof()) {
        eof_count++;
      }
    }
  }
  // 6. Add the eof buffer manually at the end
  MS_LOG(DEBUG) << "Add the eof buffer manually at the end.";
  auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
  return Status::OK();
}
Status ConcatOp::Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf) {
  TensorRow new_row;
  buf->GetRow(0, &new_row);

  if (id == 0) {
    // Obtain the column name, data type and data rank of child[0]
    column_name_id_ = child_[id]->column_name_id_map();
    for (auto item : new_row) {
      data_type_.push_back(item->type());
      data_rank_.push_back(item->Rank());
    }
  } else {
    // Compare the column name, data type and data rank with those of child[0]
    if (child_[id]->column_name_id_map() != column_name_id_) {
      RETURN_STATUS_UNEXPECTED("The column name or column order is not the same as in the previous dataset.");
    }
    int32_t index = 0;
    for (auto item : new_row) {
      if ((item->type() != data_type_[index]) || item->Rank() != data_rank_[index++]) {
        RETURN_STATUS_UNEXPECTED("The data type or data rank is not the same as in the previous dataset.");
      }
    }
  }
  return Status::OK();
}

Status ConcatOp::PrepareNodePostAction() {
  RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
  tree_->AddToRepeatStack(shared_from_this());
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,95 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_
#define DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "dataset/engine/datasetops/pipeline_op.h"

namespace mindspore {
namespace dataset {
class ConcatOp : public PipelineOp {
 public:
  // The nested builder class inside of the ConcatOp is used to help manage all of the arguments
  // for constructing it. This Concat op is very simple though, so this builder is really just
  // provided for a consistent look and feel for creators of Dataset operators overall.
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args
    // @return This is a constructor.
    Builder();

    // Default destructor
    ~Builder() = default;

    // The builder "build" method creates the final object.
    // @return shared_ptr to the new ConcatOp object
    Status Build(std::shared_ptr<ConcatOp> *);

   private:
    int32_t builder_op_connector_size_;
  };
  // Constructor of the ConcatOp.
  // @note The builder class should be used to call it
  // @param op_connector_size - connector size
  explicit ConcatOp(int32_t op_connector_size);

  // Destructor
  ~ConcatOp() = default;

  // A print method typically used for debugging
  // @param out - The output stream to write output to
  // @param show_all - A bool to control if you want to show all info or just a summary
  void Print(std::ostream &out, bool show_all) const override;

  // << Stream output operator overload
  // @notes This allows you to write the debug print info using stream operators
  // @param out - reference to the output stream being overloaded
  // @param ro - reference to the ConcatOp to display
  // @return - the output stream must be returned
  friend std::ostream &operator<<(std::ostream &out, const ConcatOp &ro) {
    ro.Print(out, false);
    return out;
  }

  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
  // provide the master loop that drives the logic for performing the work
  // @return Status - The error code return
  Status operator()() override;

  // During tree prepare phase, operators may have specific post-operations to perform depending on
  // their role.
  // @notes Derived versions of this function should always call its superclass version first
  // before providing their own implementations.
  Status PrepareNodePostAction() override;

 private:
  Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf);

  int32_t children_num_;                                     // The number of children of this node
  std::unordered_map<std::string, int32_t> column_name_id_;  // Mapping between column name and column id
  std::vector<DataType> data_type_;
  std::vector<dsize_t> data_rank_;
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_
@@ -44,7 +44,7 @@ from .validators import check, check_batch, check_shuffle, check_map, check_filt
     check_rename, \
     check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
     check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
-    check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset
+    check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
 
 try:
@@ -147,6 +147,9 @@ class Dataset:
         self._repeat_count = None
         self._sync = False
 
+    def __add__(self, datasets):
+        return self.concat(datasets)
+
     def get_args(self):
         """
         Returns attributes (member variables) related to the current class.
@@ -560,6 +563,37 @@ class Dataset:
             raise TypeError("The zip function %s type error!" % (datasets))
         return ZipDataset(datasets)
 
+    @check_concat
+    def concat(self, datasets):
+        """
+        Concatenate the datasets in the input list of datasets. The "+" operator is also supported as a
+        shorthand for this operation.
+
+        Note:
+            The column name, column data type and rank of column data should be the same across the input datasets.
+
+        Args:
+            datasets (list or class Dataset): A list of datasets or a single class Dataset
+                to be concatenated with this dataset.
+
+        Returns:
+            ConcatDataset, the concatenated dataset.
+
+        Examples:
+            >>> import mindspore.dataset as ds
+            >>> # ds1 and ds2 are instances of Dataset object
+            >>> # creates a dataset by concatenating ds1 and ds2 with the "+" operation
+            >>> data1 = ds1 + ds2
+            >>> # creates a dataset by concatenating ds1 and ds2 with the concat operation
+            >>> data1 = ds1.concat(ds2)
+        """
+        if isinstance(datasets, Dataset):
+            datasets = [self] + [datasets]
+        elif isinstance(datasets, list):
+            datasets = [self] + datasets
+        else:
+            raise TypeError("The concat function %s type error!" % (datasets))
+        return ConcatDataset(datasets)
+
     @check_rename
     def rename(self, input_columns, output_columns):
         """
@@ -1658,6 +1692,39 @@ class ZipDataset(DatasetOp):
         return args
 
+class ConcatDataset(DatasetOp):
+    """
+    The result of applying the concat dataset operator to the input Dataset.
+
+    Args:
+        datasets (list): A list of datasets to be concatenated together.
+
+    Raises:
+        TypeError: If dataset is not an instance of Dataset.
+    """
+
+    def __init__(self, datasets):
+        super().__init__()
+        for dataset in datasets:
+            if not isinstance(dataset, Dataset):
+                raise TypeError("The parameter %s of concat has type error!" % (dataset))
+        self.datasets = datasets
+        for data in datasets:
+            self.input.append(data)
+            data.output.append(self)
+
+    def get_dataset_size(self):
+        """
+        Get the number of batches in an epoch.
+
+        Returns:
+            Number, number of batches.
+        """
+        children_sizes = [c.get_dataset_size() for c in self.input]
+        dataset_size = np.sum(children_sizes)
+        return dataset_size
+
 class RenameDataset(DatasetOp):
     """
     The result of applying Rename operator to the input Dataset.
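The size bookkeeping is a plain sum over the children. A small sketch, modeled on test_concat_12 below and reusing the illustrative data1/data2 from the sketch above (the sizes are set explicitly here, exactly as the test does for GeneratorDataset):

    data1.set_dataset_size(3)
    data2.set_dataset_size(7)
    data3 = data1 + data2
    assert data3.get_dataset_size() == 10  # 3 + 7, the sum of the children's sizes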
@@ -156,6 +156,8 @@ class Iterator:
             op_type = OpName.BARRIER
         elif isinstance(dataset, de.ZipDataset):
             op_type = OpName.ZIP
+        elif isinstance(dataset, de.ConcatDataset):
+            op_type = OpName.CONCAT
         elif isinstance(dataset, de.MapDataset):
             op_type = OpName.MAP
         elif isinstance(dataset, de.FilterDataset):
@@ -335,6 +335,10 @@ def create_node(node):
         # Create ZipDataset instance, giving dummy input dataset that will be overrided in the caller.
         pyobj = de.ZipDataset((de.Dataset(), de.Dataset()))
 
+    elif dataset_op == 'ConcatDataset':
+        # Create ConcatDataset instance, giving dummy input dataset that will be overridden in the caller.
+        pyobj = de.ConcatDataset((de.Dataset(), de.Dataset()))
+
     elif dataset_op == 'RenameDataset':
         pyobj = de.Dataset().rename(node['input_columns'], node['output_columns'])
@@ -902,6 +902,26 @@ def check_zip_dataset(method):
     return new_method
 
+def check_concat(method):
+    """check the input arguments of the concat method in `Dataset`."""
+
+    @wraps(method)
+    def new_method(*args, **kwargs):
+        param_dict = make_param_dict(method, args, kwargs)
+
+        # check datasets; required argument
+        ds = param_dict.get("datasets")
+        if ds is None:
+            raise ValueError("datasets is not provided.")
+
+        if not isinstance(ds, (list, datasets.Dataset)):
+            raise ValueError("datasets is not a list or of type Dataset.")
+
+        return method(*args, **kwargs)
+
+    return new_method
+
 def check_rename(method):
     """check the input arguments of rename."""
@@ -66,6 +66,7 @@ SET(DE_UT_SRCS
   take_op_test.cc
   text_file_op_test.cc
   filter_op_test.cc
+  concat_op_test.cc
 )
 
 add_executable(de_ut_tests ${DE_UT_SRCS})
@@ -0,0 +1,125 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>
#include <memory>
#include <vector>
#include "common/common.h"
#include "common/utils.h"
#include "dataset/core/client.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"

namespace common = mindspore::common;

using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;

class MindDataTestConcatOp : public UT::DatasetOpTesting {};

TEST_F(MindDataTestConcatOp, TestConcatProject) {
  /* Tree:
   *
   *                  OpId(2) ConcatOp
   *                 /                \
   *   OpId(0) TFReaderOp      OpId(1) TFReaderOp
   *
   * Start with an empty execution tree
   */
  MS_LOG(INFO) << "UT test TestConcatProject.";
  auto my_tree = std::make_shared<ExecutionTree>();

  std::string dataset_path;
  dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";

  // TFReaderOp1
  std::shared_ptr<TFReaderOp> my_tfreader_op1;
  TFReaderOp::Builder builder1;
  builder1.SetDatasetFilesList({dataset_path})
      .SetRowsPerBuffer(16)
      .SetWorkerConnectorSize(16)
      .SetNumWorkers(16);
  std::unique_ptr<DataSchema> schema1 = std::make_unique<DataSchema>();
  schema1->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {});
  builder1.SetDataSchema(std::move(schema1));
  Status rc = builder1.Build(&my_tfreader_op1);
  ASSERT_TRUE(rc.IsOk());
  rc = my_tree->AssociateNode(my_tfreader_op1);
  ASSERT_TRUE(rc.IsOk());

  // TFReaderOp2
  std::shared_ptr<TFReaderOp> my_tfreader_op2;
  TFReaderOp::Builder builder2;
  builder2.SetDatasetFilesList({dataset_path})
      .SetRowsPerBuffer(16)
      .SetWorkerConnectorSize(16)
      .SetNumWorkers(16);
  std::unique_ptr<DataSchema> schema2 = std::make_unique<DataSchema>();
  schema2->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {});
  builder2.SetDataSchema(std::move(schema2));
  rc = builder2.Build(&my_tfreader_op2);
  ASSERT_TRUE(rc.IsOk());
  rc = my_tree->AssociateNode(my_tfreader_op2);
  ASSERT_TRUE(rc.IsOk());

  // Creating ConcatOp
  std::shared_ptr<ConcatOp> concat_op;
  rc = ConcatOp::Builder().Build(&concat_op);
  EXPECT_TRUE(rc.IsOk());

  rc = my_tree->AssociateNode(concat_op);
  EXPECT_TRUE(rc.IsOk());
  rc = concat_op->AddChild(std::move(my_tfreader_op1));
  EXPECT_TRUE(rc.IsOk());
  rc = concat_op->AddChild(std::move(my_tfreader_op2));
  EXPECT_TRUE(rc.IsOk());
  rc = my_tree->AssignRoot(concat_op);
  EXPECT_TRUE(rc.IsOk());
  rc = my_tree->Prepare();
  EXPECT_TRUE(rc.IsOk());

  // Launch the tree execution to kick off threads and start running the pipeline
  MS_LOG(INFO) << "Launching my tree.";
  rc = my_tree->Launch();
  EXPECT_TRUE(rc.IsOk());

  // Simulate a parse of data from our pipeline.
  std::shared_ptr<DatasetOp> rootNode = my_tree->root();

  DatasetIterator di(my_tree);
  TensorRow tensor_list;
  rc = di.FetchNextTensorRow(&tensor_list);
  EXPECT_TRUE(rc.IsOk());

  int row_count = 0;
  while (!tensor_list.empty()) {
    MS_LOG(INFO) << "Row display for row #: " << row_count << ".";

    // Display the tensor by calling the printer on it
    for (int i = 0; i < tensor_list.size(); i++) {
      std::ostringstream ss;
      ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
      MS_LOG(INFO) << "Tensor print: " << common::SafeCStr(ss.str()) << ".";
    }
    rc = di.FetchNextTensorRow(&tensor_list);
    EXPECT_TRUE(rc.IsOk());
    row_count++;
  }
  ASSERT_EQ(row_count, 24);  // Should be 24 rows fetched
}
@@ -0,0 +1,377 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.py_transforms as F
import mindspore.dataset.transforms.c_transforms as C
import mindspore.common.dtype as mstype
from mindspore import log as logger
import numpy as np


# In the generator dataset: number of rows is 3, values are 0, 1, 2
def generator():
    for i in range(3):
        yield (np.array([i]),)


# In the generator_10 dataset: number of rows is 7, values are 3, 4, 5 ... 9
def generator_10():
    for i in range(3, 10):
        yield (np.array([i]),)


# In the generator_20 dataset: number of rows is 10, values are 10, 11, 12 ... 19
def generator_20():
    for i in range(10, 20):
        yield (np.array([i]),)
def test_concat_01():
    """
    Test concat: concat 2 datasets that have the same column name and data type
    """
    logger.info("test_concat_01")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = data1 + data2

    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum([1 for _ in data3]) == 10


def test_concat_02():
    """
    Test concat: concat 2 datasets using the concat operation instead of the "+" operation
    """
    logger.info("test_concat_02")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = data1.concat(data2)

    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum([1 for _ in data3]) == 10


def test_concat_03():
    """
    Test concat: concat datasets that have different column names
    """
    logger.info("test_concat_03")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col2"])
    data3 = data1 + data2

    try:
        for i, d in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_04():
    """
    Test concat: concat datasets that have different ranks
    """
    logger.info("test_concat_04")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col2"])
    data2 = data2.batch(3)
    data3 = data1 + data2

    try:
        for i, d in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_05():
    """
    Test concat: concat datasets that have different data types
    """
    logger.info("test_concat_05")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    type_cast_op = C.TypeCast(mstype.float32)
    data1 = data1.map(input_columns=["col1"], operations=type_cast_op)
    data3 = data1 + data2

    try:
        for i, d in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass
def test_concat_06():
    """
    Test concat: concat multiple datasets at one time
    """
    logger.info("test_concat_06")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = ds.GeneratorDataset(generator_20, ["col1"])

    dataset = data1 + data2 + data3

    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(dataset):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum([1 for _ in dataset]) == 20


def test_concat_07():
    """
    Test concat: concat one dataset with multiple datasets (a list of datasets)
    """
    logger.info("test_concat_07")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = ds.GeneratorDataset(generator_20, ["col1"])

    dataset = [data2] + [data3]
    data4 = data1 + dataset

    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data4):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum([1 for _ in data4]) == 20


def test_concat_08():
    """
    Test concat: concat 2 datasets, then repeat
    """
    logger.info("test_concat_08")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = data1 + data2
    data3 = data3.repeat(2)

    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert i % 10 == d[0][0]

    assert sum([1 for _ in data3]) == 20


def test_concat_09():
    """
    Test concat: concat 2 datasets, both of which have been repeated beforehand
    """
    logger.info("test_concat_09")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data1 = data1.repeat(2)
    data2 = data2.repeat(2)
    data3 = data1 + data2

    res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 4, 5, 6, 7, 8, 9]
    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 20
def test_concat_10():
    """
    Test concat: concat 2 datasets, one of which has been repeated beforehand
    """
    logger.info("test_concat_10")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data1 = data1.repeat(2)
    data3 = data1 + data2

    res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 13


def test_concat_11():
    """
    Test concat: batch the datasets, then concat
    """
    logger.info("test_concat_11")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])
    data1 = data1.batch(3)
    data2 = data2.batch(5)
    data3 = data1 + data2

    res = [0, 10, 15, 20]
    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 3


def test_concat_12():
    """
    Test concat: concat the datasets, then shuffle
    """
    logger.info("test_concat_12")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data1.set_dataset_size(3)
    data2.set_dataset_size(7)
    data3 = data1 + data2

    res = [8, 6, 2, 5, 0, 4, 9, 3, 7, 1]
    ds.config.set_seed(1)
    assert data3.get_dataset_size() == 10
    data3 = data3.shuffle(buffer_size=10)

    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 10


def test_concat_13():
    """
    Test concat: batch the datasets, then concat and shuffle
    """
    logger.info("test_concat_13")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])
    data1.set_dataset_size(3)
    data2.set_dataset_size(10)
    data1 = data1.batch(3)
    data2 = data2.batch(5)
    data3 = data1 + data2

    res = [15, 0, 10]
    ds.config.set_seed(1)
    assert data3.get_dataset_size() == 3
    data3 = data3.shuffle(buffer_size=int(data3.get_dataset_size()))

    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 3
def test_concat_14():
    """
    Test concat: create datasets from different dataset folders, apply different operations, then concat
    """
    logger.info("test_concat_14")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = "../data/dataset/testImageNetData/train/"

    data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=3)
    data2 = ds.ImageFolderDatasetV2(DATA_DIR2, num_samples=2)
    transforms1 = F.ComposeOp([F.Decode(),
                               F.Resize((224, 224)),
                               F.ToTensor()])

    data1 = data1.map(input_columns=["image"], operations=transforms1())
    data2 = data2.map(input_columns=["image"], operations=transforms1())
    data3 = data1 + data2

    expected, output = [], []
    for d in data1:
        expected.append(d[0])
    for d in data2:
        expected.append(d[0])
    for d in data3:
        output.append(d[0])

    assert len(expected) == len(output)
    assert np.array_equal(np.array(output), np.array(expected))

    assert sum([1 for _ in data3]) == 5
    assert data3.get_dataset_size() == 5


def test_concat_15():
    """
    Test concat: create datasets from different formats of dataset files, then concat
    """
    logger.info("test_concat_15")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]

    data1 = ds.ImageFolderDatasetV2(DATA_DIR)
    data2 = ds.TFRecordDataset(DATA_DIR2, columns_list=["image"])

    data1 = data1.project(["image"])
    data3 = data1 + data2

    assert sum([1 for _ in data3]) == 47


if __name__ == "__main__":
    test_concat_01()
    test_concat_02()
    test_concat_03()
    test_concat_04()
    test_concat_05()
    test_concat_06()
    test_concat_07()
    test_concat_08()
    test_concat_09()
    test_concat_10()
    test_concat_11()
    test_concat_12()
    test_concat_13()
    test_concat_14()
    test_concat_15()