@@ -44,7 +44,14 @@ PYBIND_REGISTER(
            [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
       .def("SetBatchParameters",
           [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
-     .def("LaunchTreeExec", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.LaunchTreeExec(num_epochs)); })
+     .def("PrepareTree", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.PrepareTree(num_epochs)); })
+     .def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); })
+     .def("GetColumnNames",
+          [](DEPipeline &de) {
+            py::list out;
+            THROW_IF_ERROR(de.GetColumnNames(&out));
+            return out;
+          })
      .def("GetNextAsMap",
           [](DEPipeline &de) {
             py::dict out;
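
These bindings split the old single-call launch into a two-phase lifecycle. A minimal sketch of the new call sequence (hedged; `de` stands for an already-constructed DEPipeline binding object, whose construction is not shown in this change):

    # Hedged sketch of the two-phase lifecycle exposed by the bindings above.
    def run_pipeline(de, num_epochs):
        de.PrepareTree(num_epochs)   # prepare the tree; column maps become available
        cols = de.GetColumnNames()   # legal once the tree is prepared
        de.LaunchTreeExec()          # launch only when rows will actually be pulled
        return cols, de.GetNextAsMap()
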
@@ -172,9 +172,11 @@ Status DEPipeline::AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &
 // Function to assign the node as root.
 Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); }
 
+// Function to prepare the tree for execution.
+Status DEPipeline::PrepareTree(const int32_t num_epochs) { return tree_->Prepare(num_epochs); }
+
 // Function to launch the tree execution.
-Status DEPipeline::LaunchTreeExec(const int32_t num_epochs) {
-  RETURN_IF_NOT_OK(tree_->Prepare(num_epochs));
+Status DEPipeline::LaunchTreeExec() {
   RETURN_IF_NOT_OK(tree_->Launch());
   iterator_ = std::make_unique<DatasetIterator>(tree_);
   if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator.");
@@ -189,6 +191,25 @@ void DEPipeline::PrintTree() {
   }
 }
 
+Status DEPipeline::GetColumnNames(py::list *output) {
+  if (!tree_->isPrepared()) {
+    RETURN_STATUS_UNEXPECTED("GetColumnNames: Make sure to call Prepare before calling GetColumnNames.");
+  }
+  std::unordered_map<std::string, int32_t> column_name_id_map = tree_->root()->column_name_id_map();
+  if (column_name_id_map.empty())
+    RETURN_STATUS_UNEXPECTED("GetColumnNames: Column name map was empty. Make sure Prepare is called.");
+  std::vector<std::pair<std::string, int32_t>> column_name_id_vector(column_name_id_map.begin(),
+                                                                     column_name_id_map.end());
+  std::sort(column_name_id_vector.begin(), column_name_id_vector.end(),
+            [](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
+              return a.second < b.second;
+            });
+  for (const auto &item : column_name_id_vector) {
+    (*output).append(item.first);
+  }
+  return Status::OK();
+}
+
 Status DEPipeline::GetNextAsMap(py::dict *output) {
   TensorMap row;
   Status s;
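
In Python terms, the ordering logic above amounts to sorting the (name, id) pairs from the root node's column map by id, so columns come out in tree order. An illustrative restatement (example map values are assumptions):

    # Illustrative only: mirrors the std::sort comparator in GetColumnNames above.
    column_name_id_map = {"label": 1, "image": 0}   # example map
    ordered = [name for name, _ in sorted(column_name_id_map.items(), key=lambda kv: kv[1])]
    assert ordered == ["image", "label"]
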
@@ -92,8 +92,14 @@ class DEPipeline {
   // Function to assign the node as root.
   Status AssignRootNode(const DsOpPtr &dataset_op);
 
+  // Function to get the column names of the last node in the tree, in order.
+  Status GetColumnNames(py::list *output);
+
+  // Function to prepare the tree for execution.
+  Status PrepareTree(const int32_t num_epochs);
+
   // Function to launch the tree execution.
-  Status LaunchTreeExec(int32_t num_epochs);
+  Status LaunchTreeExec();
 
   // Get a row of data as dictionary of column name to the value.
   Status GetNextAsMap(py::dict *output);
@@ -83,7 +83,8 @@ void GeneratorOp::Dealloc() noexcept {
   PyGILState_STATE gstate;
   gstate = PyGILState_Ensure();
   // GC the generator object within GIL
-  (void)generator_.dec_ref();
+  if (generator_function_.ref_count() == 1) generator_function_.dec_ref();
+  if (generator_.ref_count() == 1) (void)generator_.dec_ref();
   // Release GIL
   PyGILState_Release(gstate);
 }
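
The new guards drop a reference only when this op holds the last one, so a generator shared with a copied iterator is not destroyed out from under the other owner. A rough Python analogue of the check (pybind11's `handle::ref_count()` reads the same CPython refcount that `sys.getrefcount` reports, plus one for the call's own argument):

    import sys

    def sole_owner(obj):
        # ref_count() == 1 in the C++ guard corresponds to 2 here, because
        # sys.getrefcount counts the temporary reference it holds itself.
        return sys.getrefcount(obj) == 2
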
@@ -211,6 +211,13 @@ class ExecutionTree {
   // @return Bool - true if ExecutionTree is finished
   bool isFinished() const { return tree_state_ == TreeState::kDeTStateFinished; }
 
+  // Return if the ExecutionTree has been prepared.
+  // @return Bool - true if ExecutionTree is ready
+  bool isPrepared() const {
+    return tree_state_ == TreeState::kDeTStateReady || tree_state_ == TreeState::kDeTStateExecuting ||
+           tree_state_ == TreeState::kDeTStateFinished;
+  }
+
   // Set the ExecutionTree to Finished state.
   void SetFinished() { tree_state_ = TreeState::kDeTStateFinished; }
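
`isPrepared()` treats Ready, Executing, and Finished all as prepared states, so GetColumnNames stays valid even after launch. Restated as a sketch (enumerator names copied from the header; the pre-prepare state name is an assumption):

    from enum import Enum, auto

    class TreeState(Enum):
        kDeTStateInit = auto()       # assumed name for the pre-prepare state
        kDeTStateReady = auto()
        kDeTStateExecuting = auto()
        kDeTStateFinished = auto()

    def is_prepared(state):
        return state in (TreeState.kDeTStateReady,
                         TreeState.kDeTStateExecuting,
                         TreeState.kDeTStateFinished)
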
@@ -38,7 +38,7 @@ from mindspore._c_expression import typing
 from mindspore import log as logger
 from . import samplers
-from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp
+from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator
 from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
     check_rename, check_numpyslicesdataset, check_device_send, \
     check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
@@ -1203,6 +1203,12 @@ class Dataset:
         self._repeat_count = device_iter.get_repeat_count()
         device_iter.stop()
 
+    def get_col_names(self):
+        """
+        Get the names of the columns in the dataset, in order.
+        """
+        return Iterator(self).get_col_names()
+
     def output_shapes(self):
         """
         Get the shapes of output data.
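
Because the base Iterator now only prepares the tree (LaunchTreeExec moved into the derived iterators below), `get_col_names()` is cheap: it never launches execution or consumes rows. Typical usage, matching the new tests at the end of this change:

    import mindspore.dataset as ds

    data = ds.Cifar10Dataset("../data/dataset/testCifar10Data")
    assert data.get_col_names() == ["image", "label"]
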
@@ -93,7 +93,7 @@ class Iterator:
         root = self.__convert_node_postorder(self.dataset)
         self.depipeline.AssignRootNode(root)
 
-        self.depipeline.LaunchTreeExec(self.num_epochs)
+        self.depipeline.PrepareTree(self.num_epochs)
         self._index = 0
 
     def stop(self):
@@ -276,6 +276,9 @@ class Iterator:
     def num_classes(self):
         return self.depipeline.GetNumClasses()
 
+    def get_col_names(self):
+        return self.depipeline.GetColumnNames()
+
     def __deepcopy__(self, memo):
         return self
@@ -283,6 +286,10 @@ class SaveOp(Iterator):
     """
     The derived class of Iterator with dict type.
     """
+
+    def __init__(self, dataset, num_epochs=-1):
+        super().__init__(dataset, num_epochs)
+        self.depipeline.LaunchTreeExec()
 
     def get_next(self):
         pass
@@ -298,6 +305,10 @@ class DictIterator(Iterator):
     """
     The derived class of Iterator with dict type.
    """
+
+    def __init__(self, dataset, num_epochs=-1):
+        super().__init__(dataset, num_epochs)
+        self.depipeline.LaunchTreeExec()
 
     def check_node_type(self, node):
         pass
@@ -328,6 +339,7 @@ class TupleIterator(Iterator):
             columns = [columns]
             dataset = dataset.project(columns)
         super().__init__(dataset, num_epochs)
+        self.depipeline.LaunchTreeExec()
 
     def __iter__(self):
         return self
@@ -0,0 +1,198 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import numpy as np
+
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.vision.c_transforms as vision
+
+CELEBA_DIR = "../data/dataset/testCelebAData"
+CIFAR10_DIR = "../data/dataset/testCifar10Data"
+CIFAR100_DIR = "../data/dataset/testCifar100Data"
+CLUE_DIR = "../data/dataset/testCLUE/afqmc/train.json"
+COCO_DIR = "../data/dataset/testCOCO/train"
+COCO_ANNOTATION = "../data/dataset/testCOCO/annotations/train.json"
+CSV_DIR = "../data/dataset/testCSV/1.csv"
+IMAGE_FOLDER_DIR = "../data/dataset/testPK/data/"
+MANIFEST_DIR = "../data/dataset/testManifestData/test.manifest"
+MNIST_DIR = "../data/dataset/testMnistData"
+TFRECORD_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
+TFRECORD_SCHEMA = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
+VOC_DIR = "../data/dataset/testVOC2012"
+
+
+def test_get_column_name_celeba():
+    data = ds.CelebADataset(CELEBA_DIR)
+    assert data.get_col_names() == ["image", "attr"]
+
+
+def test_get_column_name_cifar10():
+    data = ds.Cifar10Dataset(CIFAR10_DIR)
+    assert data.get_col_names() == ["image", "label"]
+
+
+def test_get_column_name_cifar100():
+    data = ds.Cifar100Dataset(CIFAR100_DIR)
+    assert data.get_col_names() == ["image", "coarse_label", "fine_label"]
+
+
+def test_get_column_name_clue():
+    data = ds.CLUEDataset(CLUE_DIR, task="AFQMC", usage="train")
+    assert data.get_col_names() == ["label", "sentence1", "sentence2"]
+
+
+def test_get_column_name_coco():
+    data = ds.CocoDataset(COCO_DIR, annotation_file=COCO_ANNOTATION, task="Detection",
+                          decode=True, shuffle=False)
+    assert data.get_col_names() == ["image", "bbox", "category_id", "iscrowd"]
+
+
+def test_get_column_name_csv():
+    data = ds.CSVDataset(CSV_DIR)
+    assert data.get_col_names() == ["1", "2", "3", "4"]
+    data = ds.CSVDataset(CSV_DIR, column_names=["col1", "col2", "col3", "col4"])
+    assert data.get_col_names() == ["col1", "col2", "col3", "col4"]
+
+
+def test_get_column_name_generator():
+    def generator():
+        for i in range(64):
+            yield (np.array([i]),)
+
+    data = ds.GeneratorDataset(generator, ["data"])
+    assert data.get_col_names() == ["data"]
+
+
+def test_get_column_name_imagefolder():
+    data = ds.ImageFolderDatasetV2(IMAGE_FOLDER_DIR)
+    assert data.get_col_names() == ["image", "label"]
+
+
+def test_get_column_name_iterator():
+    data = ds.Cifar10Dataset(CIFAR10_DIR)
+    itr = data.create_tuple_iterator(num_epochs=1)
+    assert itr.get_col_names() == ["image", "label"]
+    itr = data.create_dict_iterator(num_epochs=1)
+    assert itr.get_col_names() == ["image", "label"]
+
+
+def test_get_column_name_manifest():
+    data = ds.ManifestDataset(MANIFEST_DIR)
+    assert data.get_col_names() == ["image", "label"]
+
+
+def test_get_column_name_map():
+    data = ds.Cifar10Dataset(CIFAR10_DIR)
+    center_crop_op = vision.CenterCrop(10)
+    data = data.map(input_columns=["image"], operations=center_crop_op)
+    assert data.get_col_names() == ["image", "label"]
+    data = ds.Cifar10Dataset(CIFAR10_DIR)
+    data = data.map(input_columns=["image"], operations=center_crop_op, output_columns=["image"])
+    assert data.get_col_names() == ["image", "label"]
+    data = ds.Cifar10Dataset(CIFAR10_DIR)
+    data = data.map(input_columns=["image"], operations=center_crop_op, output_columns=["col1"])
+    assert data.get_col_names() == ["col1", "label"]
+    data = ds.Cifar10Dataset(CIFAR10_DIR)
+    data = data.map(input_columns=["image"], operations=center_crop_op, output_columns=["col1", "col2"],
+                    columns_order=["col2", "col1"])
+    assert data.get_col_names() == ["col2", "col1"]
+
+
+def test_get_column_name_mnist():
+    data = ds.MnistDataset(MNIST_DIR)
+    assert data.get_col_names() == ["image", "label"]
+
+
+def test_get_column_name_numpy_slices():
+    np_data = {"a": [1, 2], "b": [3, 4]}
+    data = ds.NumpySlicesDataset(np_data, shuffle=False)
+    assert data.get_col_names() == ["a", "b"]
+    data = ds.NumpySlicesDataset([1, 2, 3], shuffle=False)
+    assert data.get_col_names() == ["column_0"]
+
+
+def test_get_column_name_tfrecord():
+    data = ds.TFRecordDataset(TFRECORD_DIR, TFRECORD_SCHEMA)
+    assert data.get_col_names() == ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", "col_sint16", "col_sint32",
+                                    "col_sint64"]
+    data = ds.TFRecordDataset(TFRECORD_DIR, TFRECORD_SCHEMA,
+                              columns_list=["col_sint16", "col_sint64", "col_2d", "col_binary"])
+    assert data.get_col_names() == ["col_sint16", "col_sint64", "col_2d", "col_binary"]
+    data = ds.TFRecordDataset(TFRECORD_DIR)
+    assert data.get_col_names() == ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", "col_sint16", "col_sint32",
+                                    "col_sint64", "col_sint8"]
+    s = ds.Schema()
+    s.add_column("line", "string", [])
+    s.add_column("words", "string", [-1])
+    s.add_column("chinese", "string", [])
+    data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=s)
+    assert data.get_col_names() == ["line", "words", "chinese"]
+
+
+def test_get_column_name_to_device():
+    data = ds.Cifar10Dataset(CIFAR10_DIR)
+    data = data.to_device()
+    assert data.get_col_names() == ["image", "label"]
+
+
+def test_get_column_name_voc():
+    data = ds.VOCDataset(VOC_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
+    assert data.get_col_names() == ["image", "target"]
+
+
+def test_get_column_name_project():
+    data = ds.Cifar10Dataset(CIFAR10_DIR)
+    assert data.get_col_names() == ["image", "label"]
+    data = data.project(columns=["image"])
+    assert data.get_col_names() == ["image"]
+
+
+def test_get_column_name_rename():
+    data = ds.Cifar10Dataset(CIFAR10_DIR)
+    assert data.get_col_names() == ["image", "label"]
+    data = data.rename(["image", "label"], ["test1", "test2"])
+    assert data.get_col_names() == ["test1", "test2"]
+
+
+def test_get_column_name_zip():
+    data1 = ds.Cifar10Dataset(CIFAR10_DIR)
+    assert data1.get_col_names() == ["image", "label"]
+    data2 = ds.CSVDataset(CSV_DIR)
+    assert data2.get_col_names() == ["1", "2", "3", "4"]
+    data = ds.zip((data1, data2))
+    assert data.get_col_names() == ["image", "label", "1", "2", "3", "4"]
+
+
+if __name__ == "__main__":
+    test_get_column_name_celeba()
+    test_get_column_name_cifar10()
+    test_get_column_name_cifar100()
+    test_get_column_name_clue()
+    test_get_column_name_coco()
+    test_get_column_name_csv()
+    test_get_column_name_generator()
+    test_get_column_name_imagefolder()
+    test_get_column_name_iterator()
+    test_get_column_name_manifest()
+    test_get_column_name_map()
+    test_get_column_name_mnist()
+    test_get_column_name_numpy_slices()
+    test_get_column_name_tfrecord()
+    test_get_column_name_to_device()
+    test_get_column_name_voc()
+    test_get_column_name_project()
+    test_get_column_name_rename()
+    test_get_column_name_zip()