Merge pull request !1157 from ms_yan/concat_dataset
@@ -53,6 +53,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
    {kRepeat, &DEPipeline::ParseRepeatOp},
    {kSkip, &DEPipeline::ParseSkipOp},
    {kZip, &DEPipeline::ParseZipOp},
    {kConcat, &DEPipeline::ParseConcatOp},
    {kRename, &DEPipeline::ParseRenameOp},
    {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp},
    {kGenerator, &DEPipeline::ParseGeneratorOp},

@@ -757,6 +758,14 @@ Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *
  return Status::OK();
}

Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  std::shared_ptr<ConcatOp::Builder> builder = std::make_shared<ConcatOp::Builder>();
  std::shared_ptr<ConcatOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *ptr = op;
  return Status::OK();
}

Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  // Required arguments
  std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>();

@@ -46,6 +46,7 @@ enum OpName {
  kSkip,
  kTake,
  kZip,
  kConcat,
  kMap,
  kFilter,
  kDeviceQueue,

@@ -127,6 +128,8 @@ class DEPipeline {
  Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
  Status ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
  Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
  Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

@@ -476,6 +476,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
    .value("SKIP", OpName::kSkip)
    .value("TAKE", OpName::kTake)
    .value("ZIP", OpName::kZip)
    .value("CONCAT", OpName::kConcat)
    .value("MAP", OpName::kMap)
    .value("FILTER", OpName::kFilter)
    .value("DEVICEQUEUE", OpName::kDeviceQueue)

@@ -42,6 +42,7 @@
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/datasetops/zip_op.h"
#include "dataset/engine/datasetops/concat_op.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/util/status.h"

@@ -17,6 +17,7 @@ add_library(engine-datasetops OBJECT
    take_op.cc
    shuffle_op.cc
    zip_op.cc
    concat_op.cc
    filter_op.cc
)
@@ -0,0 +1,145 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iomanip>
#include <utility>
#include "common/utils.h"
#include "dataset/core/config_manager.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/datasetops/concat_op.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"

namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
ConcatOp::Builder::Builder() {
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  builder_op_connector_size_ = cfg->op_connector_size();
}

// The builder "build" method creates the final object.
Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
  *ptr = std::make_shared<ConcatOp>(builder_op_connector_size_);
  return Status::OK();
}

// Constructor of the ConcatOp.
ConcatOp::ConcatOp(int32_t op_connector_size) : PipelineOp(op_connector_size), children_num_(0) {}

// A function that prints info about the operator
void ConcatOp::Print(std::ostream &out, bool show_all) const {
  // Always show the id and name as the first line, regardless of whether this is a summary or detailed print
  out << "(" << std::setw(2) << operator_id_ << ") <ConcatOp>:";
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    PipelineOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    PipelineOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nDatasets: " << children_num_ << "\n\n";
  }
}

// Main entry point for Concat
Status ConcatOp::operator()() {
  // children_num_ must be initialized here, once all children have been attached
  children_num_ = static_cast<int32_t>(child_.size());
  TaskManager::FindMe()->Post();
  std::unique_ptr<DataBuffer> buf;
  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));

  // Obtain the column_name_id_map from child_[0]
  column_name_id_map_ = child_[0]->column_name_id_map();
  if (column_name_id_map_.empty()) {
    RETURN_STATUS_UNEXPECTED("Child column name map cannot be empty!");
  }
  int eof_count = 0;
  while (eof_count != children_num_) {
    for (int i = 0; i < children_num_; i++) {
      // 1. Throw away an eof or eoe buffer when we meet it
      if (buf->eof() || buf->eoe()) {
        RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
      }
      // 2. Verify the column name, column data type and rank of the column data
      RETURN_IF_NOT_OK(Verify(i, buf));
      // 3. Put the data into the output_connector
      while (!buf->eoe() && !buf->eof()) {
        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
        RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
      }
      // 4. Throw away the eoe buffer when we meet it
      if (buf->eoe() && (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat))) {
        RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
      }
      // 5. Add an eoe buffer after getting buffers from all children
      if (i == (children_num_ - 1)) {
        auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
      }
      if (buf->eof()) {
        eof_count++;
      }
    }
  }
  // 6. Add an eof buffer manually at the end
  MS_LOG(DEBUG) << "Add the eof buffer manually at the end.";
  auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
  return Status::OK();
}
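The buffer flow above is easiest to see end to end. Below is a minimal Python sketch (an editorial illustration, not part of the patch; concat_flow and the EOE/EOF stand-ins are hypothetical) modeling the single-epoch case: each child is drained in order, the children's own flow-control markers are swallowed, and exactly one eoe plus one final eof are emitted.

EOE, EOF = "EOE", "EOF"  # stand-ins for the flow-control buffers

def concat_flow(children):
    """Model one epoch: each child is a list of data buffers ending in EOE, EOF."""
    out = []
    for child in children:
        for buf in child:
            if buf in (EOE, EOF):
                continue           # steps 1/4: drop each child's own markers
            out.append(buf)        # step 3: forward ordinary data buffers
    out.append(EOE)                # step 5: one eoe after the last child
    out.append(EOF)                # step 6: one eof appended manually
    return out

assert concat_flow([[1, 2, EOE, EOF], [3, 4, EOE, EOF]]) == [1, 2, 3, 4, EOE, EOF]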
Status ConcatOp::Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf) {
  TensorRow new_row;
  buf->GetRow(0, &new_row);

  if (id == 0) {
    // Record the column name map, data types and data ranks of child[0]
    column_name_id_ = child_[id]->column_name_id_map();
    for (auto item : new_row) {
      data_type_.push_back(item->type());
      data_rank_.push_back(item->Rank());
    }
  } else {
    // Compare the column names, data types and data ranks against those of child[0]
    if (child_[id]->column_name_id_map() != column_name_id_) {
      RETURN_STATUS_UNEXPECTED("The column name or column order is not the same as in the previous dataset.");
    }
    int32_t index = 0;
    for (auto item : new_row) {
      if ((item->type() != data_type_[index]) || item->Rank() != data_rank_[index++]) {
        RETURN_STATUS_UNEXPECTED("The data type or data rank is not the same as in the previous dataset.");
      }
    }
  }
  return Status::OK();
}
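In outline, Verify fixes a reference schema from the first child and checks every later child against it. A small Python sketch of that logic (the verify helper and its tuple layout are hypothetical, for illustration only):

def verify(children_meta):
    """children_meta: one (column_map, types, ranks) tuple per child."""
    ref_columns, ref_types, ref_ranks = children_meta[0]  # child[0] sets the reference
    for columns, types, ranks in children_meta[1:]:
        if columns != ref_columns:
            raise RuntimeError("The column name or column order is not the same "
                               "as in the previous dataset.")
        if types != ref_types or ranks != ref_ranks:
            raise RuntimeError("The data type or data rank is not the same "
                               "as in the previous dataset.")

# Two children with matching schemas pass; any mismatch raises.
verify([({"col1": 0}, ["int64"], [1]), ({"col1": 0}, ["int64"], [1])])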
Status ConcatOp::PrepareNodePostAction() {
  RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
  tree_->AddToRepeatStack(shared_from_this());
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,95 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_
#define DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "dataset/engine/datasetops/pipeline_op.h"

namespace mindspore {
namespace dataset {
class ConcatOp : public PipelineOp {
 public:
  // The nested builder class inside of the ConcatOp is used to help manage all of the arguments
  // for constructing it. This Concat op is very simple though, so this builder is really just
  // provided for a consistent look and feel for creators of Dataset operators overall.
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args
    // @return This is a constructor.
    Builder();

    // Default destructor
    ~Builder() = default;

    // The builder "build" method creates the final object.
    // @return shared_ptr to the new ConcatOp object
    Status Build(std::shared_ptr<ConcatOp> *);

   private:
    int32_t builder_op_connector_size_;
  };

  // Constructor of the ConcatOp.
  // @note The builder class should be used to call it
  // @param op_connector_size - connector size
  explicit ConcatOp(int32_t op_connector_size);

  // Destructor
  ~ConcatOp() = default;

  // A print method typically used for debugging
  // @param out - The output stream to write output to
  // @param show_all - A bool to control if you want to show all info or just a summary
  void Print(std::ostream &out, bool show_all) const override;

  // << Stream output operator overload
  // @notes This allows you to write the debug print info using stream operators
  // @param out - reference to the output stream being overloaded
  // @param ro - reference to the ConcatOp to display
  // @return - the output stream must be returned
  friend std::ostream &operator<<(std::ostream &out, const ConcatOp &ro) {
    ro.Print(out, false);
    return out;
  }

  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
  // provide the master loop that drives the logic for performing the work
  // @return Status - The error code returned
  Status operator()() override;

  // During the tree prepare phase, operators may have specific post-operations to perform
  // depending on their role.
  // @notes Derived versions of this function should always call its superclass version first
  // before providing their own implementation.
  Status PrepareNodePostAction() override;

 private:
  Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf);

  int32_t children_num_;                                     // The number of children of this node
  std::unordered_map<std::string, int32_t> column_name_id_;  // Mapping from column name to column index
  std::vector<DataType> data_type_;
  std::vector<dsize_t> data_rank_;
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_
@@ -44,7 +44,7 @@ from .validators import check, check_batch, check_shuffle, check_map, check_filt
     check_rename, \
     check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
     check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
-    check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset
+    check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

 try:

@@ -147,6 +147,9 @@ class Dataset:
        self._repeat_count = None
        self._sync = False

    def __add__(self, datasets):
        return self.concat(datasets)

    def get_args(self):
        """
        Returns attributes (member variables) related to the current class.
@@ -560,6 +563,37 @@ class Dataset:
            raise TypeError("The zip function %s type error!" % (datasets))
        return ZipDataset(datasets)

    @check_concat
    def concat(self, datasets):
        """
        Concatenate the datasets in the input list of datasets. The "+" operator is
        also supported as an overload of the concat operation.

        Note:
            The column name, column data type and rank of column data should be the same
            across the input datasets.

        Args:
            datasets (list or class Dataset): A list of datasets or a single class Dataset
                to be concatenated together with this dataset.

        Returns:
            ConcatDataset, the concatenated dataset.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # ds1 and ds2 are instances of Dataset object
            >>> # creates a dataset by concatenating ds1 and ds2 with the "+" operator
            >>> data1 = ds1 + ds2
            >>> # creates a dataset by concatenating ds1 and ds2 with the concat operation
            >>> data1 = ds1.concat(ds2)
        """
        if isinstance(datasets, Dataset):
            datasets = [self] + [datasets]
        elif isinstance(datasets, list):
            datasets = [self] + datasets
        else:
            raise TypeError("The concat function %s type error!" % (datasets))
        return ConcatDataset(datasets)
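A short end-to-end sketch of both spellings, mirroring the GeneratorDataset sources used by the unit tests at the bottom of this change (illustrative only):

import numpy as np
import mindspore.dataset as ds

def gen_a():            # 3 rows: 0, 1, 2
    for i in range(3):
        yield (np.array([i]),)

def gen_b():            # 7 rows: 3 .. 9
    for i in range(3, 10):
        yield (np.array([i]),)

data1 = ds.GeneratorDataset(gen_a, ["col1"])
data2 = ds.GeneratorDataset(gen_b, ["col1"])

data3 = data1 + data2            # equivalent to data1.concat(data2)
assert sum(1 for _ in data3) == 10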
    @check_rename
    def rename(self, input_columns, output_columns):
        """

@@ -1658,6 +1692,39 @@ class ZipDataset(DatasetOp):
        return args


class ConcatDataset(DatasetOp):
    """
    The result of applying the concat dataset operator to the input Dataset.

    Args:
        datasets (list): A list of datasets to be concatenated together.

    Raises:
        TypeError: If dataset is not an instance of Dataset.
    """

    def __init__(self, datasets):
        super().__init__()
        for dataset in datasets:
            if not isinstance(dataset, Dataset):
                raise TypeError("The parameter %s of concat has type error!" % (dataset))
        self.datasets = datasets
        for data in datasets:
            self.input.append(data)
            data.output.append(self)

    def get_dataset_size(self):
        """
        Get the number of batches in an epoch.

        Returns:
            Number, number of batches.
        """
        children_sizes = [c.get_dataset_size() for c in self.input]
        dataset_size = np.sum(children_sizes)
        return dataset_size
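The size of a concatenated dataset is simply the sum of its children's sizes. A small sketch grounded in test_concat_12 below (set_dataset_size is called first, matching the test, since the generator sources do not report a size on their own):

import numpy as np
import mindspore.dataset as ds

def gen_small():        # 3 rows
    for i in range(3):
        yield (np.array([i]),)

def gen_large():        # 7 rows
    for i in range(3, 10):
        yield (np.array([i]),)

data1 = ds.GeneratorDataset(gen_small, ["col1"])
data2 = ds.GeneratorDataset(gen_large, ["col1"])
data1.set_dataset_size(3)
data2.set_dataset_size(7)

data3 = data1 + data2
assert data3.get_dataset_size() == 10   # 3 + 7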
class RenameDataset(DatasetOp):
    """
    The result of applying the Rename operator to the input Dataset.

@@ -156,6 +156,8 @@ class Iterator:
            op_type = OpName.BARRIER
        elif isinstance(dataset, de.ZipDataset):
            op_type = OpName.ZIP
        elif isinstance(dataset, de.ConcatDataset):
            op_type = OpName.CONCAT
        elif isinstance(dataset, de.MapDataset):
            op_type = OpName.MAP
        elif isinstance(dataset, de.FilterDataset):

@@ -335,6 +335,10 @@ def create_node(node):
        # Create ZipDataset instance, giving a dummy input dataset that will be overridden in the caller.
        pyobj = de.ZipDataset((de.Dataset(), de.Dataset()))

    elif dataset_op == 'ConcatDataset':
        # Create ConcatDataset instance, giving a dummy input dataset that will be overridden in the caller.
        pyobj = de.ConcatDataset((de.Dataset(), de.Dataset()))

    elif dataset_op == 'RenameDataset':
        pyobj = de.Dataset().rename(node['input_columns'], node['output_columns'])
@@ -902,6 +902,26 @@ def check_zip_dataset(method):
    return new_method


def check_concat(method):
    """check the input arguments of the concat method in `Dataset`."""

    @wraps(method)
    def new_method(*args, **kwargs):
        param_dict = make_param_dict(method, args, kwargs)

        # check datasets; required argument
        ds = param_dict.get("datasets")
        if ds is None:
            raise ValueError("datasets is not provided.")

        if not isinstance(ds, (list, datasets.Dataset)):
            raise ValueError("datasets is not a list or of type Dataset.")

        return method(*args, **kwargs)

    return new_method
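A hedged illustration of the validator's behavior, reusing the generator pattern from the tests (for illustration only):

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(3):
        yield (np.array([i]),)

data = ds.GeneratorDataset(gen, ["col1"])
try:
    data.concat("not_a_dataset")       # neither a list nor a Dataset
except ValueError as err:
    print(err)                         # datasets is not a list or of type Dataset.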
def check_rename(method):
    """check the input arguments of rename."""

@@ -66,6 +66,7 @@ SET(DE_UT_SRCS
    take_op_test.cc
    text_file_op_test.cc
    filter_op_test.cc
    concat_op_test.cc
)

add_executable(de_ut_tests ${DE_UT_SRCS})
@@ -0,0 +1,125 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>
#include <memory>
#include <vector>
#include "common/common.h"
#include "common/utils.h"
#include "dataset/core/client.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"

namespace common = mindspore::common;

using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;

class MindDataTestConcatOp : public UT::DatasetOpTesting {};

TEST_F(MindDataTestConcatOp, TestConcatProject) {
  /* Tree:
   *
   *            OpId(2) ConcatOp
   *           /                \
   *    OpId(0) TFReaderOp   OpId(1) TFReaderOp
   *
   * Start with an empty execution tree
   */
  MS_LOG(INFO) << "UT test TestConcatProject.";
  auto my_tree = std::make_shared<ExecutionTree>();

  std::string dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";

  // TFReaderOp1
  std::shared_ptr<TFReaderOp> my_tfreader_op1;
  TFReaderOp::Builder builder1;
  builder1.SetDatasetFilesList({dataset_path})
    .SetRowsPerBuffer(16)
    .SetWorkerConnectorSize(16)
    .SetNumWorkers(16);
  std::unique_ptr<DataSchema> schema1 = std::make_unique<DataSchema>();
  schema1->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {});
  builder1.SetDataSchema(std::move(schema1));
  Status rc = builder1.Build(&my_tfreader_op1);
  ASSERT_TRUE(rc.IsOk());
  rc = my_tree->AssociateNode(my_tfreader_op1);
  ASSERT_TRUE(rc.IsOk());

  // TFReaderOp2
  std::shared_ptr<TFReaderOp> my_tfreader_op2;
  TFReaderOp::Builder builder2;
  builder2.SetDatasetFilesList({dataset_path})
    .SetRowsPerBuffer(16)
    .SetWorkerConnectorSize(16)
    .SetNumWorkers(16);
  std::unique_ptr<DataSchema> schema2 = std::make_unique<DataSchema>();
  schema2->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {});
  builder2.SetDataSchema(std::move(schema2));
  rc = builder2.Build(&my_tfreader_op2);
  ASSERT_TRUE(rc.IsOk());
  rc = my_tree->AssociateNode(my_tfreader_op2);
  ASSERT_TRUE(rc.IsOk());

  // Creating the ConcatOp
  std::shared_ptr<ConcatOp> concat_op;
  rc = ConcatOp::Builder().Build(&concat_op);
  EXPECT_TRUE(rc.IsOk());

  rc = my_tree->AssociateNode(concat_op);
  EXPECT_TRUE(rc.IsOk());
  rc = concat_op->AddChild(std::move(my_tfreader_op1));
  EXPECT_TRUE(rc.IsOk());
  rc = concat_op->AddChild(std::move(my_tfreader_op2));
  EXPECT_TRUE(rc.IsOk());
  rc = my_tree->AssignRoot(concat_op);
  EXPECT_TRUE(rc.IsOk());
  rc = my_tree->Prepare();
  EXPECT_TRUE(rc.IsOk());

  // Launch the tree execution to kick off threads and start running the pipeline
  MS_LOG(INFO) << "Launching my tree.";
  rc = my_tree->Launch();
  EXPECT_TRUE(rc.IsOk());

  // Simulate a parse of data from our pipeline.
  std::shared_ptr<DatasetOp> rootNode = my_tree->root();

  DatasetIterator di(my_tree);
  TensorRow tensor_list;
  rc = di.FetchNextTensorRow(&tensor_list);
  EXPECT_TRUE(rc.IsOk());

  int row_count = 0;
  while (!tensor_list.empty()) {
    MS_LOG(INFO) << "Row display for row #: " << row_count << ".";
    // Display the tensor by calling the printer on it
    for (int i = 0; i < tensor_list.size(); i++) {
      std::ostringstream ss;
      ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
      MS_LOG(INFO) << "Tensor print: " << common::SafeCStr(ss.str()) << ".";
    }
    rc = di.FetchNextTensorRow(&tensor_list);
    EXPECT_TRUE(rc.IsOk());
    row_count++;
  }
  ASSERT_EQ(row_count, 24);  // Should be 24 rows fetched (both readers contribute)
}
@@ -0,0 +1,377 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.py_transforms as F
from mindspore import log as logger


# The generator dataset has 3 rows, with values 0, 1, 2
def generator():
    for i in range(3):
        yield (np.array([i]),)


# The generator_10 dataset has 7 rows, with values 3, 4, ..., 9
def generator_10():
    for i in range(3, 10):
        yield (np.array([i]),)


# The generator_20 dataset has 10 rows, with values 10, 11, ..., 19
def generator_20():
    for i in range(10, 20):
        yield (np.array([i]),)
def test_concat_01():
    """
    Test concat: concatenate 2 datasets that have the same column name and data type
    """
    logger.info("test_concat_01")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = data1 + data2

    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum([1 for _ in data3]) == 10


def test_concat_02():
    """
    Test concat: concatenate 2 datasets using the concat operation rather than "+"
    """
    logger.info("test_concat_02")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = data1.concat(data2)

    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum([1 for _ in data3]) == 10


def test_concat_03():
    """
    Test concat: concatenating datasets that have different column names fails
    """
    logger.info("test_concat_03")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col2"])
    data3 = data1 + data2

    try:
        for _ in data3:
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_04():
    """
    Test concat: concatenating datasets whose column data have different ranks fails
    """
    logger.info("test_concat_04")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data2 = data2.batch(3)  # batching adds a dimension, so the ranks differ
    data3 = data1 + data2

    try:
        for _ in data3:
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_05():
    """
    Test concat: concatenating datasets that have different data types fails
    """
    logger.info("test_concat_05")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    type_cast_op = C.TypeCast(mstype.float32)
    data1 = data1.map(input_columns=["col1"], operations=type_cast_op)

    data3 = data1 + data2

    try:
        for _ in data3:
            pass
        assert False
    except RuntimeError:
        pass
def test_concat_06():
    """
    Test concat: concatenate multiple datasets at one time
    """
    logger.info("test_concat_06")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = ds.GeneratorDataset(generator_20, ["col1"])

    dataset = data1 + data2 + data3

    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(dataset):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum([1 for _ in dataset]) == 20


def test_concat_07():
    """
    Test concat: concatenate one dataset with multiple datasets (a list of datasets)
    """
    logger.info("test_concat_07")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = ds.GeneratorDataset(generator_20, ["col1"])

    dataset = [data2] + [data3]
    data4 = data1 + dataset

    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data4):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum([1 for _ in data4]) == 20


def test_concat_08():
    """
    Test concat: concatenate 2 datasets, then repeat
    """
    logger.info("test_concat_08")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = data1 + data2
    data3 = data3.repeat(2)

    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert i % 10 == d[0][0]

    assert sum([1 for _ in data3]) == 20


def test_concat_09():
    """
    Test concat: concatenate 2 datasets, both of which have been repeated beforehand
    """
    logger.info("test_concat_09")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data1 = data1.repeat(2)
    data2 = data2.repeat(2)
    data3 = data1 + data2

    res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 4, 5, 6, 7, 8, 9]

    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 20


def test_concat_10():
    """
    Test concat: concatenate 2 datasets, one of which has been repeated beforehand
    """
    logger.info("test_concat_10")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data1 = data1.repeat(2)
    data3 = data1 + data2

    res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 13
def test_concat_11():
    """
    Test concat: batch the datasets, then concatenate
    """
    logger.info("test_concat_11")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])

    data1 = data1.batch(3)
    data2 = data2.batch(5)

    data3 = data1 + data2

    res = [0, 10, 15, 20]

    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 3


def test_concat_12():
    """
    Test concat: concatenate the datasets, then shuffle
    """
    logger.info("test_concat_12")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data1.set_dataset_size(3)
    data2.set_dataset_size(7)

    data3 = data1 + data2
    res = [8, 6, 2, 5, 0, 4, 9, 3, 7, 1]

    ds.config.set_seed(1)
    assert data3.get_dataset_size() == 10

    data3 = data3.shuffle(buffer_size=10)

    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 3 + 7


def test_concat_13():
    """
    Test concat: batch the datasets, concatenate, then shuffle
    """
    logger.info("test_concat_13")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])
    data1.set_dataset_size(3)
    data2.set_dataset_size(10)

    data1 = data1.batch(3)
    data2 = data2.batch(5)

    data3 = data1 + data2
    res = [15, 0, 10]

    ds.config.set_seed(1)
    assert data3.get_dataset_size() == 3

    data3 = data3.shuffle(buffer_size=int(data3.get_dataset_size()))

    # Here i refers to the index, d refers to the data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 3
def test_concat_14():
    """
    Test concat: create datasets from different dataset folders, apply different operations, then concatenate
    """
    logger.info("test_concat_14")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = "../data/dataset/testImageNetData/train/"

    data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=3)
    data2 = ds.ImageFolderDatasetV2(DATA_DIR2, num_samples=2)

    transforms1 = F.ComposeOp([F.Decode(),
                               F.Resize((224, 224)),
                               F.ToTensor()])

    data1 = data1.map(input_columns=["image"], operations=transforms1())
    data2 = data2.map(input_columns=["image"], operations=transforms1())
    data3 = data1 + data2

    expected, output = [], []
    for d in data1:
        expected.append(d[0])
    for d in data2:
        expected.append(d[0])
    for d in data3:
        output.append(d[0])

    assert len(expected) == len(output)
    assert np.array_equal(np.array(output), np.array(expected))

    assert sum([1 for _ in data3]) == 5
    assert data3.get_dataset_size() == 5


def test_concat_15():
    """
    Test concat: create datasets from different file formats, then concatenate
    """
    logger.info("test_concat_15")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]

    data1 = ds.ImageFolderDatasetV2(DATA_DIR)
    data2 = ds.TFRecordDataset(DATA_DIR2, columns_list=["image"])

    data1 = data1.project(["image"])
    data3 = data1 + data2

    assert sum([1 for _ in data3]) == 47


if __name__ == "__main__":
    test_concat_01()
    test_concat_02()
    test_concat_03()
    test_concat_04()
    test_concat_05()
    test_concat_06()
    test_concat_07()
    test_concat_08()
    test_concat_09()
    test_concat_10()
    test_concat_11()
    test_concat_12()
    test_concat_13()
    test_concat_14()
    test_concat_15()