| @@ -16,10 +16,9 @@ | |||
| #include "minddata/dataset/api/python/de_pipeline.h" | |||
| #include <algorithm> | |||
| #include <set> | |||
| #include <map> | |||
| #include <set> | |||
| #include "utils/ms_utils.h" | |||
| #include "minddata/dataset/callback/py_ds_callback.h" | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||
| @@ -32,15 +31,15 @@ | |||
| #include "minddata/dataset/engine/datasetops/source/celeba_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/cifar_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/clue_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/csv_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/coco_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/csv_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/manifest_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/mnist_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/text_file_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/voc_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "minddata/dataset/kernels/py_func_op.h" | |||
| #include "minddata/dataset/util/random.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| @@ -53,6 +52,7 @@ | |||
| #include "minddata/mindrecord/include/shard_writer.h" | |||
| #include "pybind11/stl.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/ms_utils.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -211,16 +211,17 @@ Status DEPipeline::GetColumnNames(py::list *output) { | |||
| } | |||
| Status DEPipeline::GetNextAsMap(py::dict *output) { | |||
| TensorMap row; | |||
| std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> vec; | |||
| Status s; | |||
| { | |||
| py::gil_scoped_release gil_release; | |||
| s = iterator_->GetNextAsMap(&row); | |||
| s = iterator_->GetNextAsOrderedPair(&vec); | |||
| } | |||
| RETURN_IF_NOT_OK(s); | |||
| // Generate Python dict as return | |||
| for (auto el : row) { | |||
| (*output)[common::SafeCStr(el.first)] = el.second; | |||
| // Generate Python dict, python dict maintains its insertion order | |||
| for (const auto &pair : vec) { | |||
| (*output)[common::SafeCStr(pair.first)] = pair.second; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -614,7 +615,7 @@ Status DEPipeline::FetchMetaFromTensorRow(const std::unordered_map<std::string, | |||
| } | |||
| if (mr_shape.empty()) { | |||
| if (mr_type == "bytes") { // map to int32 when bytes without shape. | |||
| mr_type == "int32"; | |||
| mr_type = "int32"; | |||
| } | |||
| (*schema)[column_name] = {{"type", mr_type}}; | |||
| } else { | |||
| @@ -905,7 +906,7 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||
| if (py::isinstance<py::int_>(args["batch_size"])) { | |||
| batch_size_ = ToInt(args["batch_size"]); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(batch_size_ > 0, "Error: batch_size is invalid."); | |||
| builder = std::make_shared<BatchOp::Builder>(ToInt(args["batch_size"])); | |||
| builder = std::make_shared<BatchOp::Builder>(batch_size_); | |||
| } else if (py::isinstance<py::function>(args["batch_size"])) { | |||
| builder = std::make_shared<BatchOp::Builder>(1); | |||
| (void)builder->SetBatchSizeFunc(args["batch_size"].cast<py::function>()); | |||
| @@ -920,17 +921,13 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||
| if (!value.is_none()) { | |||
| if (key == "drop_remainder") { | |||
| (void)builder->SetDrop(ToBool(value)); | |||
| } | |||
| if (key == "num_parallel_workers") { | |||
| } else if (key == "num_parallel_workers") { | |||
| (void)builder->SetNumWorkers(ToInt(value)); | |||
| } | |||
| if (key == "per_batch_map") { | |||
| } else if (key == "per_batch_map") { | |||
| (void)builder->SetBatchMapFunc(value.cast<py::function>()); | |||
| } | |||
| if (key == "input_columns") { | |||
| } else if (key == "input_columns") { | |||
| (void)builder->SetColumnsToMap(ToStringVector(value)); | |||
| } | |||
| if (key == "pad_info") { | |||
| } else if (key == "pad_info") { | |||
| PadInfo pad_info; | |||
| RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info)); | |||
| (void)builder->SetPaddingMap(pad_info, true); | |||
| @@ -81,6 +81,40 @@ Status IteratorBase::FetchNextTensorRow(TensorRow *out_row) { | |||
| return Status::OK(); | |||
| } | |||
| Status IteratorBase::GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *vec) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(vec != nullptr && vec->empty(), "vec is null or non-empty."); | |||
| TensorRow curr_row; | |||
| RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row)); | |||
| RETURN_OK_IF_TRUE(curr_row.empty()); | |||
| size_t num_cols = curr_row.size(); // num_cols is non-empty. | |||
| if (col_name_id_map_.empty()) col_name_id_map_ = this->GetColumnNameMap(); | |||
| // order the column names according to their ids | |||
| if (column_order_.empty()) { | |||
| const int32_t invalid_col_id = -1; | |||
| column_order_.resize(num_cols, {std::string(), invalid_col_id}); | |||
| for (const auto itr : col_name_id_map_) { | |||
| int32_t ind = itr.second; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(ind < num_cols && ind >= 0, "column id out of bounds."); | |||
| column_order_[ind] = std::make_pair(itr.first, ind); | |||
| } | |||
| // error check, make sure the ids in col_name_id_map are continuous and starts from 0 | |||
| for (const auto &col : column_order_) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(col.second != invalid_col_id, "column ids are not continuous."); | |||
| } | |||
| } | |||
| vec->reserve(num_cols); | |||
| for (const auto &col : column_order_) { | |||
| vec->emplace_back(std::make_pair(col.first, curr_row[col.second])); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Constructor of the DatasetIterator | |||
| DatasetIterator::DatasetIterator(std::shared_ptr<ExecutionTree> exe_tree) | |||
| : IteratorBase(), | |||
| @@ -19,12 +19,14 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/engine/datasetops/dataset_op.h" | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "minddata/dataset/engine/perf/dataset_iterator_tracing.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -61,6 +63,11 @@ class IteratorBase { | |||
| // @return An unordered map from column name to shared pointer to Tensor. | |||
| Status GetNextAsMap(TensorMap *out_map); | |||
| /// \brief Return (column_name, tensor) pairs in the order of their column ids. | |||
| /// \param[out] vec | |||
| /// \return Error code | |||
| Status GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *vec); | |||
| // Getter | |||
| // @return T/F if this iterator is completely done after getting an eof | |||
| bool eof_handled() const { return eof_handled_; } | |||
| @@ -73,6 +80,7 @@ class IteratorBase { | |||
| std::unique_ptr<DataBuffer> curr_buffer_; // holds the current buffer | |||
| bool eof_handled_; // T/F if this op got an eof | |||
| std::unordered_map<std::string, int32_t> col_name_id_map_; | |||
| std::vector<std::pair<std::string, int32_t>> column_order_; // key: column name, val: column id | |||
| }; | |||
| // The DatasetIterator derived class is for fetching rows off the end/root of the execution tree. | |||
| @@ -150,12 +150,15 @@ def check_columns(columns, name): | |||
| Exception: when the value is not correct, otherwise nothing. | |||
| """ | |||
| type_check(columns, (list, str), name) | |||
| if isinstance(columns, list): | |||
| if isinstance(columns, str): | |||
| if not columns: | |||
| raise ValueError("{0} should not be an empty str".format(name)) | |||
| elif isinstance(columns, list): | |||
| if not columns: | |||
| raise ValueError("{0} should not be empty".format(name)) | |||
| for i, column_name in enumerate(columns): | |||
| if not column_name: | |||
| raise ValueError("{0}[{1}] should not be empty".format(name, i)) | |||
| raise ValueError("{0}[{1}] should not be empty.".format(name, i)) | |||
| col_names = ["{0}[{1}]".format(name, i) for i in range(len(columns))] | |||
| type_check_list(columns, (str,), col_names) | |||
| @@ -503,17 +503,13 @@ def check_batch(method): | |||
| if input_columns is not None: | |||
| check_columns(input_columns, "input_columns") | |||
| if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1): | |||
| raise ValueError("the signature of per_batch_map should match with input columns") | |||
| if (per_batch_map is None) != (input_columns is None): | |||
| # These two parameters appear together. | |||
| raise ValueError("per_batch_map and input_columns need to be passed in together.") | |||
| if input_columns is not None: | |||
| if not input_columns: # Check whether input_columns is empty. | |||
| raise ValueError("input_columns can not be empty") | |||
| if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1): | |||
| raise ValueError("the signature of per_batch_map should match with input columns") | |||
| if output_columns is not None: | |||
| raise ValueError("output_columns is currently not implemented.") | |||
| @@ -12,6 +12,8 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import numpy as np | |||
| from util import save_and_check_tuple | |||
| import mindspore.dataset as ds | |||
| @@ -155,3 +157,27 @@ def test_case_map_project_map_project(): | |||
| filename = "project_alternate_parallel_inline_result.npz" | |||
| save_and_check_tuple(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) | |||
def test_column_order():
    """Check that the dict yielded by the dict iterator keeps the column order given to project()."""

    def three_columns(count):
        # Each row holds three single-element arrays: 3*i, 3*i + 1 and 3*i + 2.
        for idx in range(count):
            base = idx * 3
            yield (np.array([base]), np.array([base + 1]), np.array([base + 2]))

    def last_row(count, order):
        # Batch all generated rows together, project to the requested order and keep the last dict seen.
        pipeline = ds.GeneratorDataset((lambda: three_columns(count)), ["col1", "col2", "col3"])
        pipeline = pipeline.batch(batch_size=count)
        pipeline = pipeline.project(list(order))
        row = dict()
        for row in pipeline.create_dict_iterator(num_epochs=1):
            pass
        return row

    cases = (
        (1, ["col3", "col2", "col1"]),
        (2, ["col1", "col2", "col3"]),
        (3, ["col2", "col3", "col1"]),
    )
    for count, order in cases:
        assert list(last_row(count, order).keys()) == order
| if __name__ == '__main__': | |||
| test_column_order() | |||
| @@ -190,14 +190,13 @@ def test_random_affine_py_exception_non_pil_images(): | |||
| Test RandomAffine: input img is ndarray and not PIL, expected to raise RuntimeError | |||
| """ | |||
| logger.info("test_random_affine_exception_negative_degrees") | |||
| dataset = ds.MnistDataset(MNIST_DATA_DIR, num_parallel_workers=3) | |||
| dataset = ds.MnistDataset(MNIST_DATA_DIR, num_samples=3, num_parallel_workers=3) | |||
| try: | |||
| transform = mindspore.dataset.transforms.py_transforms.Compose([py_vision.ToTensor(), | |||
| py_vision.RandomAffine(degrees=(15, 15))]) | |||
| dataset = dataset.map(operations=transform, input_columns=["image"], num_parallel_workers=3, | |||
| python_multiprocessing=True) | |||
| dataset = dataset.map(operations=transform, input_columns=["image"], num_parallel_workers=3) | |||
| for _ in dataset.create_dict_iterator(num_epochs=1): | |||
| break | |||
| pass | |||
| except RuntimeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Pillow image" in str(e) | |||