Merge pull request !3049 from liyong126/dataset_save_optags/v0.6.0-beta
| @@ -42,11 +42,17 @@ | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/mindrecord/include/shard_category.h" | |||
| #include "minddata/mindrecord/include/shard_distributed_sample.h" | |||
| #include "minddata/mindrecord/include/shard_header.h" | |||
| #include "minddata/mindrecord/include/shard_index_generator.h" | |||
| #include "minddata/mindrecord/include/shard_sample.h" | |||
| #include "minddata/mindrecord/include/shard_shuffle.h" | |||
| #include "minddata/mindrecord/include/shard_writer.h" | |||
| #include "pybind11/stl.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| using json = nlohmann::json; | |||
| using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr<DatasetOp> *, std::shared_ptr<DatasetOp> *); | |||
| static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = { | |||
| @@ -355,6 +361,226 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetO | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type) { | |||
| Status s; | |||
| auto mr_header = std::make_shared<mindrecord::ShardHeader>(); | |||
| auto mr_writer = std::make_unique<mindrecord::ShardWriter>(); | |||
| std::vector<std::string> blob_fields; | |||
| uint64_t mr_schema_id = 0; | |||
| if (mindrecord::SUCCESS != mindrecord::ShardWriter::initialize(&mr_writer, file_names)) { | |||
| RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardWriter."); | |||
| } | |||
| TensorRow row; | |||
| std::unordered_map<std::string, int32_t> column_name_id_map = | |||
| iterator_->GetColumnNameMap(); // map of column name, id | |||
| bool first_loop = true; // build schema in first loop | |||
| do { | |||
| json row_raw_data; | |||
| std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data; | |||
| { | |||
| py::gil_scoped_release gil_release; | |||
| s = iterator_->FetchNextTensorRow(&row); | |||
| } | |||
| RETURN_IF_NOT_OK(s); | |||
| if (row.empty()) break; | |||
| if (first_loop) { | |||
| json mr_json; | |||
| std::vector<std::string> index_fields; | |||
| s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields); | |||
| RETURN_IF_NOT_OK(s); | |||
| mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id); | |||
| mr_writer->SetShardHeader(mr_header); | |||
| first_loop = false; | |||
| } | |||
| // construct data | |||
| if (!row.empty()) { // write data | |||
| s = FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data); | |||
| RETURN_IF_NOT_OK(s); | |||
| std::shared_ptr<std::vector<uint8_t>> output_bin_data; | |||
| mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data); | |||
| std::map<std::uint64_t, std::vector<json>> raw_data; | |||
| raw_data.insert(std::pair<uint64_t, std::vector<json>>(mr_schema_id, std::vector<json>{row_raw_data})); | |||
| std::vector<std::vector<uint8_t>> bin_data; | |||
| if (nullptr != output_bin_data) { | |||
| bin_data.emplace_back(*output_bin_data); | |||
| } | |||
| mr_writer->WriteRawData(raw_data, bin_data); | |||
| } | |||
| } while (!row.empty()); | |||
| mr_writer->Commit(); | |||
| mindrecord::ShardIndexGenerator::finalize(file_names); | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::FetchDataFromTensorRow(const TensorRow &row, | |||
| const std::unordered_map<std::string, int32_t> &column_name_id_map, | |||
| json *row_raw_data, | |||
| std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data) { | |||
| if (row_raw_data == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("error: row raw data is NULL."); | |||
| } | |||
| if (row_bin_data == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("error: row bin data is NULL."); | |||
| } | |||
| if (column_name_id_map.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("Error: column not found"); | |||
| } | |||
| Status s; | |||
| for (auto &col : column_name_id_map) { | |||
| auto idx = col.second; | |||
| auto column_name = col.first; | |||
| auto &tensor = row[idx]; | |||
| auto column_type = tensor->type(); | |||
| std::unique_ptr<std::vector<uint8_t>> data_ptr; | |||
| if (column_type == DataType::DE_INT8) { | |||
| std::unique_ptr<int32_t> data; | |||
| std::unique_ptr<int8_t> dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_INT16) { | |||
| std::unique_ptr<int32_t> data; | |||
| std::unique_ptr<int16_t> dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_UINT16) { | |||
| std::unique_ptr<int32_t> data; | |||
| std::unique_ptr<uint16_t> dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_UINT8) { | |||
| std::unique_ptr<uint8_t> data, dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_INT32) { | |||
| std::unique_ptr<int32_t> data, dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_UINT32) { | |||
| std::unique_ptr<int64_t> data; | |||
| std::unique_ptr<uint32_t> dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_INT64) { | |||
| std::unique_ptr<int64_t> data, dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_FLOAT32) { | |||
| std::unique_ptr<float> data, dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_FLOAT64) { | |||
| std::unique_ptr<double> data, dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_STRING) { | |||
| auto buffer = tensor->GetStringsBuffer(); | |||
| std::string ss(reinterpret_cast<const char *>(buffer)); // assume scalar string tensor | |||
| (*row_raw_data)[column_name] = std::move(ss); | |||
| continue; | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Got unexpected type when casting data."); | |||
| } | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data_ptr != nullptr) { | |||
| (*row_bin_data)[column_name] = std::move(data_ptr); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| template <typename T, typename S> | |||
| Status DEPipeline::TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, | |||
| std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr, | |||
| std::unique_ptr<S> *s, bool need_convert) { | |||
| if (nullptr == src) { | |||
| RETURN_STATUS_UNEXPECTED("Error: buffer of Tensor is NULL."); | |||
| } | |||
| *data_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(T)); | |||
| if (need_convert) { | |||
| auto tmp_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(S)); | |||
| std::copy(src, src + sizeof(S) * num_of_elements, tmp_ptr->begin()); | |||
| auto s_ptr = reinterpret_cast<S *>(&(*(tmp_ptr->begin()))); | |||
| auto el = std::make_unique<T>(); | |||
| for (uint32_t i = 0; i < num_of_elements; ++i) { | |||
| *el = *(s_ptr + i); | |||
| auto t_ptr = reinterpret_cast<uint8_t *>(el.get()); | |||
| for (uint32_t j = 0; j < sizeof(T); ++j) { | |||
| *((*data_ptr)->begin() + i * sizeof(T) + j) = *(t_ptr + j); | |||
| } | |||
| } | |||
| } else { | |||
| std::copy(src, src + sizeof(T) * num_of_elements, (*data_ptr)->begin()); | |||
| } | |||
| if (shape.empty()) { | |||
| *data = std::make_unique<T>(); | |||
| auto t_ptr = reinterpret_cast<uint8_t *>((*data).get()); | |||
| for (uint32_t i = 0; i < sizeof(T); ++i) { | |||
| *(t_ptr + i) = *((*data_ptr)->begin() + i); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map, | |||
| const TensorRow &row, json *schema, std::vector<std::string> *index_fields) { | |||
| if (schema == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("error: schema is NULL."); | |||
| } | |||
| if (index_fields == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("error: index fields is NULL."); | |||
| } | |||
| if (column_name_id_map.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("Error: column not found."); | |||
| } | |||
| for (auto &col : column_name_id_map) { | |||
| auto idx = col.second; | |||
| auto column_name = col.first; | |||
| auto &tensor = row[idx]; | |||
| auto column_type = tensor->type(); | |||
| auto column_shape = tensor->shape(); | |||
| std::string mr_type; | |||
| auto shapes = column_shape.AsVector(); | |||
| std::vector<int> mr_shape(shapes.begin(), shapes.end()); | |||
| std::string el = column_type.ToString(); | |||
| if (mindrecord::kTypesMap.find(el) == mindrecord::kTypesMap.end()) { | |||
| std::string err_msg("Error: can not support data type: " + el); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } else { | |||
| mr_type = mindrecord::kTypesMap.at(el); | |||
| } | |||
| if (mr_shape.empty()) { | |||
| if (mr_type == "bytes") { // map to int32 when bytes without shape. | |||
| mr_type == "int32"; | |||
| } | |||
| (*schema)[column_name] = {{"type", mr_type}}; | |||
| } else { | |||
| if (mr_type == "string") { // mindrecord can not support string with shape. | |||
| std::string err_msg("Error: mindrecord can not support multi-dimensional string tensor."); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| if (mr_type == "bytes") { // ignore shape of bytes in minrecord | |||
| (*schema)[column_name] = {{"type", mr_type}}; | |||
| } else { | |||
| (*schema)[column_name] = {{"type", mr_type}, {"shape", mr_shape}}; | |||
| } | |||
| } | |||
| if (mr_type == "bytes" || !mr_shape.empty()) continue; | |||
| index_fields->emplace_back(column_name); // candidate of index fields | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle, | |||
| std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators, | |||
| int num_padded) { | |||
| @@ -17,6 +17,7 @@ | |||
| #define DATASET_API_DE_PIPELINE_H_ | |||
| #include <iostream> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <stack> | |||
| #include <string> | |||
| @@ -33,6 +34,7 @@ | |||
| namespace py = pybind11; | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| using json = nlohmann::json; | |||
| using DsOpPtr = std::shared_ptr<DatasetOp>; | |||
| class CacheClient; | |||
| @@ -100,6 +102,8 @@ class DEPipeline { | |||
| Status GetOutputTypes(py::list *output); | |||
| Status SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type); | |||
| int GetDatasetSize() const; | |||
| int GetBatchSize() const; | |||
| @@ -110,6 +114,18 @@ class DEPipeline { | |||
| Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| template <typename T, typename S> | |||
| Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, | |||
| std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr, | |||
| std::unique_ptr<S> *s, bool need_convert = false); | |||
| Status FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map, | |||
| const TensorRow &row, json *schema, std::vector<std::string> *index_fields); | |||
| Status FetchDataFromTensorRow(const TensorRow &row, | |||
| const std::unordered_map<std::string, int32_t> &column_name_id_map, json *row_raw_data, | |||
| std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data); | |||
| Status BuildMindrecordSamplerChain(const py::handle &handle, | |||
| std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators, | |||
| int num_padded); | |||
| @@ -184,7 +184,11 @@ void bindDEPipeline(py::module *m) { | |||
| .def("GetDatasetSize", &DEPipeline::GetDatasetSize) | |||
| .def("GetBatchSize", &DEPipeline::GetBatchSize) | |||
| .def("GetNumClasses", &DEPipeline::GetNumClasses) | |||
| .def("GetRepeatCount", &DEPipeline::GetRepeatCount); | |||
| .def("GetRepeatCount", &DEPipeline::GetRepeatCount) | |||
| .def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) { | |||
| THROW_IF_ERROR(de.SaveDataset(file_names, file_type)); | |||
| return true; | |||
| }); | |||
| } | |||
| void bindDatasetOps(py::module *m) { | |||
| (void)py::class_<TFReaderOp, DatasetOp, std::shared_ptr<TFReaderOp>>(*m, "TFReaderOp") | |||
| @@ -312,6 +312,11 @@ class Tensor { | |||
| // @return const unsigned char* | |||
| const unsigned char *GetBuffer() const; | |||
| // Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if the | |||
| // tensor's type is a string, otherwise undefined address would be returned. | |||
| // @return address of the first string of the tensor. | |||
| uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; } | |||
| // Getter of the type | |||
| // @return | |||
| DataType type() const { return type_; } | |||
| @@ -643,11 +648,6 @@ class Tensor { | |||
| // @return length of the string | |||
| Status GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const; | |||
| // Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if the | |||
| // tensor's type is a string, otherwise undefined address would be returned. | |||
| // @return address of the first string of the tensor. | |||
| uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; } | |||
| // all access to shape_ should be via shape | |||
| TensorShape shape_; | |||
| // data type of tensor | |||
| @@ -215,7 +215,7 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const { | |||
| // Call the super class for displaying any common detailed info | |||
| ParallelOp::Print(out, show_all); | |||
| // Then show any custom derived-internal stuff | |||
| out << "\n Dataset file : "; | |||
| out << "\nDataset file : "; | |||
| for (auto &file : dataset_file_) { | |||
| out << file << " "; | |||
| } | |||
| @@ -137,6 +137,10 @@ const std::set<std::string> kScalarFieldTypeSet = {"string", "int32", "int64", " | |||
| // number field list | |||
| const std::set<std::string> kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"}; | |||
| const std::unordered_map<std::string, std::string> kTypesMap = { | |||
| {"bool", "int32"}, {"int8", "int32"}, {"uint8", "bytes"}, {"int16", "int32"}, | |||
| {"uint16", "int32"}, {"int32", "int32"}, {"uint32", "int64"}, {"int64", "int64"}, | |||
| {"float16", "float32"}, {"float32", "float32"}, {"float64", "float64"}, {"string", "string"}}; | |||
| /// \brief split a string using a character | |||
| /// \param[in] field target string | |||
| /// \param[in] separator a character for spliting | |||
| @@ -124,6 +124,10 @@ class ShardHeader { | |||
| MSRStatus FileToPages(const std::string dump_file_name); | |||
| static MSRStatus initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema, | |||
| const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields, | |||
| uint64_t &schema_id); | |||
| private: | |||
| MSRStatus InitializeHeader(const std::vector<json> &headers, bool load_dataset); | |||
| @@ -57,6 +57,8 @@ class ShardIndexGenerator { | |||
| /// \brief create databases for indexes | |||
| MSRStatus WriteToDatabase(); | |||
| static MSRStatus finalize(const std::vector<std::string> file_names); | |||
| private: | |||
| static int Callback(void *not_used, int argc, char **argv, char **az_col_name); | |||
| @@ -108,6 +108,13 @@ class ShardWriter { | |||
| std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true, | |||
| bool parallel_writer = false); | |||
| MSRStatus MergeBlobData(const std::vector<string> &blob_fields, | |||
| const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data, | |||
| std::shared_ptr<std::vector<uint8_t>> *output); | |||
| static MSRStatus initialize(const std::unique_ptr<ShardWriter> *writer_ptr, | |||
| const std::vector<std::string> &file_names); | |||
| private: | |||
| /// \brief write shard header data to disk | |||
| MSRStatus WriteShardHeader(); | |||
| @@ -622,5 +622,21 @@ void ShardIndexGenerator::DatabaseWriter() { | |||
| shard_no = task_++; | |||
| } | |||
| } | |||
| MSRStatus ShardIndexGenerator::finalize(const std::vector<std::string> file_names) { | |||
| if (file_names.empty()) { | |||
| MS_LOG(ERROR) << "Mindrecord files is empty."; | |||
| return FAILED; | |||
| } | |||
| ShardIndexGenerator sg{file_names[0]}; | |||
| if (SUCCESS != sg.Build()) { | |||
| MS_LOG(ERROR) << "Failed to build index generator."; | |||
| return FAILED; | |||
| } | |||
| if (SUCCESS != sg.WriteToDatabase()) { | |||
| MS_LOG(ERROR) << "Failed to write to database."; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -637,6 +637,42 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json> | |||
| *row_count = std::get<2>(v); | |||
| return SUCCESS; | |||
| } | |||
| MSRStatus ShardWriter::MergeBlobData(const std::vector<string> &blob_fields, | |||
| const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data, | |||
| std::shared_ptr<std::vector<uint8_t>> *output) { | |||
| if (blob_fields.empty()) { | |||
| return SUCCESS; | |||
| } | |||
| if (blob_fields.size() == 1) { | |||
| auto &blob = row_bin_data.at(blob_fields[0]); | |||
| auto blob_size = blob->size(); | |||
| *output = std::make_shared<std::vector<uint8_t>>(blob_size); | |||
| std::copy(blob->begin(), blob->end(), (*output)->begin()); | |||
| } else { | |||
| size_t output_size = 0; | |||
| for (auto &field : blob_fields) { | |||
| output_size += row_bin_data.at(field)->size(); | |||
| } | |||
| output_size += blob_fields.size() * sizeof(uint64_t); | |||
| *output = std::make_shared<std::vector<uint8_t>>(output_size); | |||
| std::vector<uint8_t> buf(sizeof(uint64_t), 0); | |||
| size_t idx = 0; | |||
| for (auto &field : blob_fields) { | |||
| auto &blob = row_bin_data.at(field); | |||
| uint64_t blob_size = blob->size(); | |||
| // big edian | |||
| for (size_t i = 0; i < buf.size(); ++i) { | |||
| buf[buf.size() - 1 - i] = std::numeric_limits<uint8_t>::max() & blob_size; | |||
| blob_size >>= 8u; | |||
| } | |||
| std::copy(buf.begin(), buf.end(), (*output)->begin() + idx); | |||
| idx += buf.size(); | |||
| std::copy(blob->begin(), blob->end(), (*output)->begin() + idx); | |||
| idx += blob->size(); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, | |||
| std::vector<std::vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) { | |||
| @@ -1250,5 +1286,21 @@ void ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &la | |||
| last_blob_page = page.first; | |||
| } | |||
| } | |||
| MSRStatus ShardWriter::initialize(const std::unique_ptr<ShardWriter> *writer_ptr, | |||
| const std::vector<std::string> &file_names) { | |||
| if (nullptr == writer_ptr) { | |||
| MS_LOG(ERROR) << "ShardWriter pointer is NULL."; | |||
| return FAILED; | |||
| } | |||
| auto res = (*writer_ptr)->Open(file_names, false); | |||
| if (SUCCESS != res) { | |||
| MS_LOG(ERROR) << "Failed to open mindrecord files to writer."; | |||
| return FAILED; | |||
| } | |||
| (*writer_ptr)->SetHeaderSize(1 << 24); | |||
| (*writer_ptr)->SetPageSize(1 << 25); | |||
| return SUCCESS; | |||
| } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -721,5 +721,35 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { | |||
| page_in_handle.close(); | |||
| return SUCCESS; | |||
| } | |||
| MSRStatus ShardHeader::initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema, | |||
| const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields, | |||
| uint64_t &schema_id) { | |||
| if (nullptr == header_ptr) { | |||
| MS_LOG(ERROR) << "ShardHeader pointer is NULL."; | |||
| return FAILED; | |||
| } | |||
| auto schema_ptr = Schema::Build("mindrecord", schema); | |||
| if (nullptr == schema_ptr) { | |||
| MS_LOG(ERROR) << "Got unexpected error when building mindrecord schema."; | |||
| return FAILED; | |||
| } | |||
| schema_id = (*header_ptr)->AddSchema(schema_ptr); | |||
| // create index | |||
| std::vector<std::pair<uint64_t, std::string>> id_index_fields; | |||
| if (!index_fields.empty()) { | |||
| for (auto &el : index_fields) { | |||
| id_index_fields.emplace_back(schema_id, el); | |||
| } | |||
| if (SUCCESS != (*header_ptr)->AddIndexFields(id_index_fields)) { | |||
| MS_LOG(ERROR) << "Got unexpected error when adding mindrecord index."; | |||
| return FAILED; | |||
| } | |||
| } | |||
| auto build_schema_ptr = (*header_ptr)->GetSchemas()[0]; | |||
| blob_fields = build_schema_ptr->GetBlobFields(); | |||
| return SUCCESS; | |||
| } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -38,13 +38,13 @@ from mindspore._c_expression import typing | |||
| from mindspore import log as logger | |||
| from . import samplers | |||
| from .iterators import DictIterator, TupleIterator, DummyIterator | |||
| from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp | |||
| from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ | |||
| check_rename, check_numpyslicesdataset, \ | |||
| check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ | |||
| check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ | |||
| check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ | |||
| check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32 | |||
| check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32, check_save | |||
| from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist | |||
| try: | |||
| @@ -1044,6 +1044,34 @@ class Dataset: | |||
| return TransferDataset(self, queue_name, device_id, device_type, num_batch) | |||
| @check_save | |||
| def save(self, file_name, num_files=1, file_type='mindrecord'): | |||
| """ | |||
| Save the dynamic data processed by dataset pipeline as common dataset format, support: mindrecord. | |||
| Note: | |||
| 1. To save the samples in order, should set dataset's shuffle false and num_files 1. | |||
| 2. Before call the function, do not use batch, repeat operator or data augmentation operators | |||
| with random attribute in map operator. | |||
| 3. Mindreocrd do not support np.uint64, multi-dimensional np.uint8(drop dimension) and | |||
| multi-dimensional string. | |||
| Args: | |||
| file_name (str): Path to dataset file. | |||
| num_files (int, optional): Number of dataset files.(default=1). | |||
| file_type (str, optional): dataset format.(default='mindrecord') | |||
| """ | |||
| if num_files == 1: | |||
| file_names = [file_name] | |||
| else: | |||
| suffix = len(str(num_files - 1)) | |||
| file_names = ["{}{}".format(file_name, str(x).rjust(suffix, '0')) | |||
| for x in range(num_files)] | |||
| return SaveOp(self).save(file_names, file_type) | |||
| def create_tuple_iterator(self, columns=None): | |||
| """ | |||
| Create an Iterator over the dataset. The data retrieved will be a list of ndarray of data. | |||
| @@ -173,6 +173,7 @@ class Iterator: | |||
| # Convert python node into C node and add to C layer execution tree in postorder traversal. | |||
| def __convert_node_postorder(self, node): | |||
| self.check_node_type(node) | |||
| op_type = self.__get_dataset_type(node) | |||
| c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args()) | |||
| @@ -224,6 +225,10 @@ class Iterator: | |||
| self._index += 1 | |||
| return data | |||
| @abstractmethod | |||
| def check_node_type(self, node): | |||
| pass | |||
| def get_output_shapes(self): | |||
| return [t for t in self.depipeline.GetOutputShapes()] | |||
| @@ -245,11 +250,27 @@ class Iterator: | |||
| def __deepcopy__(self, memo): | |||
| return self | |||
| class SaveOp(Iterator): | |||
| """ | |||
| The derived class of Iterator with dict type. | |||
| """ | |||
| def get_next(self): | |||
| pass | |||
| def check_node_type(self, node): | |||
| if isinstance(node, (de.ShuffleDataset, de.RepeatDataset, de.BatchDataset)): | |||
| logger.warning("Used shuffle, repeat, batch before save operator.") | |||
| def save(self, file_names, file_type): | |||
| return self.depipeline.SaveDataset(file_names, file_type) | |||
| class DictIterator(Iterator): | |||
| """ | |||
| The derived class of Iterator with dict type. | |||
| """ | |||
| def check_node_type(self, node): | |||
| pass | |||
| def __iter__(self): | |||
| return self | |||
| @@ -269,6 +290,8 @@ class TupleIterator(Iterator): | |||
| """ | |||
| The derived class of Iterator with list type. | |||
| """ | |||
| def check_node_type(self, node): | |||
| pass | |||
| def __init__(self, dataset, columns=None): | |||
| if columns is not None: | |||
| @@ -246,7 +246,24 @@ def check_celebadataset(method): | |||
| return new_method | |||
| def check_save(method): | |||
| """A wrapper that wrap a parameter checker to the save op.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| _, param_dict = parse_user_args(method, *args, **kwargs) | |||
| nreq_param_int = ['num_files'] | |||
| nreq_param_str = ['file_name', 'file_type'] | |||
| validate_dataset_param_value(nreq_param_int, param_dict, int) | |||
| if(param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000): | |||
| raise ValueError("num_files should between {} and {}.".format(1, 1000)) | |||
| validate_dataset_param_value(nreq_param_str, param_dict, str) | |||
| if param_dict.get('file_type') != 'mindrecord': | |||
| raise ValueError("{} dataset format is not supported.".format(param_dict.get('file_type'))) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| def check_minddataset(method): | |||
| """A wrapper that wraps a parameter checker to the original Dataset(MindDataset).""" | |||
| @@ -0,0 +1,390 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| This is the test module for saveOp. | |||
| """ | |||
| import os | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| from mindspore.mindrecord import FileWriter | |||
| import numpy as np | |||
| import pytest | |||
| CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord" | |||
| CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord" | |||
| FILES_NUM = 1 | |||
| num_readers = 1 | |||
| @pytest.fixture(name="add_and_remove_cv_file") | |||
| def fixture_remove(): | |||
| """add/remove cv file""" | |||
| if os.path.exists("{}".format(CV_FILE_NAME1)): | |||
| os.remove("{}".format(CV_FILE_NAME1)) | |||
| if os.path.exists("{}.db".format(CV_FILE_NAME1)): | |||
| os.remove("{}.db".format(CV_FILE_NAME1)) | |||
| if os.path.exists("{}".format(CV_FILE_NAME2)): | |||
| os.remove("{}".format(CV_FILE_NAME2)) | |||
| if os.path.exists("{}.db".format(CV_FILE_NAME2)): | |||
| os.remove("{}.db".format(CV_FILE_NAME2)) | |||
| yield "yield_cv_data" | |||
| if os.path.exists("{}".format(CV_FILE_NAME1)): | |||
| os.remove("{}".format(CV_FILE_NAME1)) | |||
| if os.path.exists("{}.db".format(CV_FILE_NAME1)): | |||
| os.remove("{}.db".format(CV_FILE_NAME1)) | |||
| if os.path.exists("{}".format(CV_FILE_NAME2)): | |||
| os.remove("{}".format(CV_FILE_NAME2)) | |||
| if os.path.exists("{}.db".format(CV_FILE_NAME2)): | |||
| os.remove("{}.db".format(CV_FILE_NAME2)) | |||
| def test_case_00(add_and_remove_cv_file): # only bin data | |||
| data = [{"image1": bytes("image1 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image1 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image1 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image1 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image1 bytes mno", encoding='UTF-8')}, | |||
| {"image1": bytes("image2 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image2 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image2 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image2 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image2 bytes mno", encoding='UTF-8')}, | |||
| {"image1": bytes("image3 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image3 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image3 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image3 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image3 bytes mno", encoding='UTF-8')}, | |||
| {"image1": bytes("image5 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image5 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image5 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image5 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image5 bytes mno", encoding='UTF-8')}, | |||
| {"image1": bytes("image6 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image6 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image6 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image6 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image6 bytes mno", encoding='UTF-8')}] | |||
| schema = { | |||
| "image1": {"type": "bytes"}, | |||
| "image2": {"type": "bytes"}, | |||
| "image3": {"type": "bytes"}, | |||
| "image4": {"type": "bytes"}, | |||
| "image5": {"type": "bytes"}} | |||
| writer = FileWriter(CV_FILE_NAME1, FILES_NUM) | |||
| writer.add_schema(schema, "schema") | |||
| writer.write_raw_data(data) | |||
| writer.commit() | |||
| d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False) | |||
| d1.save(CV_FILE_NAME2, FILES_NUM) | |||
| data_value_to_list = [] | |||
| for item in data: | |||
| new_data = {} | |||
| new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) | |||
| new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) | |||
| new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8) | |||
| new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8) | |||
| new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8) | |||
| data_value_to_list.append(new_data) | |||
| d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert d2.get_dataset_size() == 5 | |||
| num_iter = 0 | |||
| for item in d2.create_dict_iterator(): | |||
| assert len(item) == 5 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 5 | |||
| def test_case_01(add_and_remove_cv_file): # only raw data | |||
| data = [{"file_name": "001.jpg", "label": 43}, | |||
| {"file_name": "002.jpg", "label": 91}, | |||
| {"file_name": "003.jpg", "label": 61}, | |||
| {"file_name": "004.jpg", "label": 29}, | |||
| {"file_name": "005.jpg", "label": 78}, | |||
| {"file_name": "006.jpg", "label": 37}] | |||
| schema = {"file_name": {"type": "string"}, | |||
| "label": {"type": "int32"} | |||
| } | |||
| writer = FileWriter(CV_FILE_NAME1, FILES_NUM) | |||
| writer.add_schema(schema, "schema") | |||
| writer.write_raw_data(data) | |||
| writer.commit() | |||
| d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False) | |||
| d1.save(CV_FILE_NAME2, FILES_NUM) | |||
| data_value_to_list = [] | |||
| for item in data: | |||
| new_data = {} | |||
| new_data['file_name'] = np.asarray(item["file_name"], dtype='S') | |||
| new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) | |||
| data_value_to_list.append(new_data) | |||
| d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert d2.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in d2.create_dict_iterator(): | |||
| logger.info(item) | |||
| assert len(item) == 2 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| def test_case_02(add_and_remove_cv_file): # muti-bytes | |||
| data = [{"file_name": "001.jpg", "label": 43, | |||
| "float32_array": np.array([1.2, 2.78, 3.1234, 4.9871, 5.12341], dtype=np.float32), | |||
| "float64_array": np.array([48.1234556789, 49.3251241431, 50.13514312414, 51.8971298471, | |||
| 123414314.2141243, 87.1212122], dtype=np.float64), | |||
| "float32": 3456.12345, | |||
| "float64": 1987654321.123456785, | |||
| "source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int32), | |||
| "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "image1": bytes("image1 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image1 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image1 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image1 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image1 bytes mno", encoding='UTF-8')}, | |||
| {"file_name": "002.jpg", "label": 91, | |||
| "float32_array": np.array([1.2, 2.78, 4.1234, 4.9871, 5.12341], dtype=np.float32), | |||
| "float64_array": np.array([48.1234556789, 49.3251241431, 60.13514312414, 51.8971298471, | |||
| 123414314.2141243, 87.1212122], dtype=np.float64), | |||
| "float32": 3456.12445, | |||
| "float64": 1987654321.123456786, | |||
| "source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int32), | |||
| "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "image1": bytes("image2 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image2 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image2 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image2 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image2 bytes mno", encoding='UTF-8')}, | |||
| {"file_name": "003.jpg", "label": 61, | |||
| "float32_array": np.array([1.2, 2.78, 5.1234, 4.9871, 5.12341], dtype=np.float32), | |||
| "float64_array": np.array([48.1234556789, 49.3251241431, 70.13514312414, 51.8971298471, | |||
| 123414314.2141243, 87.1212122], dtype=np.float64), | |||
| "float32": 3456.12545, | |||
| "float64": 1987654321.123456787, | |||
| "source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int32), | |||
| "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "image1": bytes("image3 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image3 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image3 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image3 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image3 bytes mno", encoding='UTF-8')}, | |||
| {"file_name": "004.jpg", "label": 29, | |||
| "float32_array": np.array([1.2, 2.78, 6.1234, 4.9871, 5.12341], dtype=np.float32), | |||
| "float64_array": np.array([48.1234556789, 49.3251241431, 80.13514312414, 51.8971298471, | |||
| 123414314.2141243, 87.1212122], dtype=np.float64), | |||
| "float32": 3456.12645, | |||
| "float64": 1987654321.123456788, | |||
| "source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int32), | |||
| "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "image1": bytes("image4 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image4 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image4 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image4 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image4 bytes mno", encoding='UTF-8')}, | |||
| {"file_name": "005.jpg", "label": 78, | |||
| "float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32), | |||
| "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471, | |||
| 123414314.2141243, 87.1212122], dtype=np.float64), | |||
| "float32": 3456.12745, | |||
| "float64": 1987654321.123456789, | |||
| "source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int32), | |||
| "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "image1": bytes("image5 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image5 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image5 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image5 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image5 bytes mno", encoding='UTF-8')}, | |||
| {"file_name": "006.jpg", "label": 37, | |||
| "float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32), | |||
| "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471, | |||
| 123414314.2141243, 87.1212122], dtype=np.float64), | |||
| "float32": 3456.12745, | |||
| "float64": 1987654321.123456789, | |||
| "source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int32), | |||
| "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "image1": bytes("image6 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image6 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image6 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image6 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image6 bytes mno", encoding='UTF-8')} | |||
| ] | |||
| schema = {"file_name": {"type": "string"}, | |||
| "float32_array": {"type": "float32", "shape": [-1]}, | |||
| "float64_array": {"type": "float64", "shape": [-1]}, | |||
| "float32": {"type": "float32"}, | |||
| "float64": {"type": "float64"}, | |||
| "source_sos_ids": {"type": "int32", "shape": [-1]}, | |||
| "source_sos_mask": {"type": "int64", "shape": [-1]}, | |||
| "image1": {"type": "bytes"}, | |||
| "image2": {"type": "bytes"}, | |||
| "image3": {"type": "bytes"}, | |||
| "label": {"type": "int32"}, | |||
| "image4": {"type": "bytes"}, | |||
| "image5": {"type": "bytes"}} | |||
| writer = FileWriter(CV_FILE_NAME1, FILES_NUM) | |||
| writer.add_schema(schema, "schema") | |||
| writer.write_raw_data(data) | |||
| writer.commit() | |||
| d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False) | |||
| d1.save(CV_FILE_NAME2, FILES_NUM) | |||
| data_value_to_list = [] | |||
| for item in data: | |||
| new_data = {} | |||
| new_data['file_name'] = np.asarray(item["file_name"], dtype='S') | |||
| new_data['float32_array'] = item["float32_array"] | |||
| new_data['float64_array'] = item["float64_array"] | |||
| new_data['float32'] = item["float32"] | |||
| new_data['float64'] = item["float64"] | |||
| new_data['source_sos_ids'] = item["source_sos_ids"] | |||
| new_data['source_sos_mask'] = item["source_sos_mask"] | |||
| new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) | |||
| new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) | |||
| new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) | |||
| new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8) | |||
| new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8) | |||
| new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8) | |||
| data_value_to_list.append(new_data) | |||
| d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert d2.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in d2.create_dict_iterator(): | |||
| assert len(item) == 13 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| if item[field].dtype == np.float32: | |||
| assert (item[field] == | |||
| np.array(data_value_to_list[num_iter][field], np.float32)).all() | |||
| else: | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| def generator_1d(): | |||
| for i in range(10): | |||
| yield (np.array([i]),) | |||
| def test_case_03(add_and_remove_cv_file): | |||
| # apply dataset operations | |||
| d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False) | |||
| d1.save(CV_FILE_NAME2) | |||
| d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| i = 0 | |||
| for item in d2.create_dict_iterator(): # each data is a dictionary | |||
| golden = np.array([i]) | |||
| assert np.array_equal(item["data"], golden) | |||
| i = i + 1 | |||
| def generator_with_type(t): | |||
| for i in range(64): | |||
| yield (np.array([i], dtype=t),) | |||
| def type_tester(t): | |||
| logger.info("Test with Type {}".format(t.__name__)) | |||
| # apply dataset operations | |||
| data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"], shuffle=False) | |||
| data1 = data1.batch(4) | |||
| data1 = data1.repeat(3) | |||
| data1.save(CV_FILE_NAME2) | |||
| d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| i = 0 | |||
| num_repeat = 0 | |||
| for item in d2.create_dict_iterator(): # each data is a dictionary | |||
| golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t) | |||
| logger.info(item) | |||
| assert np.array_equal(item["data"], golden) | |||
| i = i + 4 | |||
| if i == 64: | |||
| i = 0 | |||
| num_repeat += 1 | |||
| assert num_repeat == 3 | |||
| if os.path.exists("{}".format(CV_FILE_NAME2)): | |||
| os.remove("{}".format(CV_FILE_NAME2)) | |||
| if os.path.exists("{}.db".format(CV_FILE_NAME2)): | |||
| os.remove("{}.db".format(CV_FILE_NAME2)) | |||
| def test_case_04(): | |||
| # uint8 will drop shape as mindrecord store uint8 as bytes | |||
| types = [np.int8, np.int16, np.int32, np.int64, | |||
| np.uint16, np.uint32, np.float32, np.float64] | |||
| for t in types: | |||
| type_tester(t) | |||
| def test_case_05(add_and_remove_cv_file): | |||
| d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False) | |||
| with pytest.raises(Exception, match="num_files should between 1 and 1000."): | |||
| d1.save(CV_FILE_NAME2, 0) | |||
| def test_case_06(add_and_remove_cv_file): | |||
| d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False) | |||
| with pytest.raises(Exception, match="tfrecord dataset format is not supported."): | |||
| d1.save(CV_FILE_NAME2, 1, "tfrecord") | |||