Merge pull request !1317 from liyong126/mindrecord_compress
@@ -112,25 +112,26 @@ Status MindRecordOp::Init() {
data_schema_ = std::make_unique<DataSchema>();
std::vector<std::shared_ptr<Schema>> schema_vec = shard_reader_->GetShardHeader()->GetSchemas();
// check whether schema exists, if so use the first one
CHECK_FAIL_RETURN_UNEXPECTED(!schema_vec.empty(), "No schema found");
mindrecord::json mr_schema = schema_vec[0]->GetSchema()["schema"];
std::vector<std::string> col_names = shard_reader_->get_shard_column()->GetColumnName();
CHECK_FAIL_RETURN_UNEXPECTED(!col_names.empty(), "No schema found");
std::vector<mindrecord::ColumnDataType> col_data_types = shard_reader_->get_shard_column()->GeColumnDataType();
std::vector<std::vector<int64_t>> col_shapes = shard_reader_->get_shard_column()->GetColumnShape();
bool load_all_cols = columns_to_load_.empty();  // if columns_to_load_ is empty it means load everything
std::map<std::string, int32_t> colname_to_ind;
for (mindrecord::json::iterator it = mr_schema.begin(); it != mr_schema.end(); ++it) {
std::string colname = it.key();          // key of the json, column name
mindrecord::json it_value = it.value();  // value, which contains type info and may contain shape
for (uint32_t i = 0; i < col_names.size(); i++) {
std::string colname = col_names[i];
ColDescriptor col_desc;
TensorShape t_shape = TensorShape::CreateUnknownRankShape();  // shape of tensor, default unknown
std::string type_str = (it_value["type"] == "bytes" || it_value["type"] == "string") ? "uint8" : it_value["type"];
std::string type_str = mindrecord::ColumnDataTypeNameNormalized[col_data_types[i]];
DataType t_dtype = DataType(type_str);  // valid types: {"bytes", "string", "int32", "int64", "float32", "float64"}
if (it_value["type"] == "bytes") {  // rank = 1
if (col_data_types[i] == mindrecord::ColumnBytes || col_data_types[i] == mindrecord::ColumnString) {  // rank = 1
col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, 1);
} else if (it_value.find("shape") != it_value.end()) {
std::vector<dsize_t> vec(it_value["shape"].size());  // temporary vector to hold shape
(void)std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin());
} else if (col_shapes[i].size() > 0) {
std::vector<dsize_t> vec(col_shapes[i].size());  // temporary vector to hold shape
(void)std::copy(col_shapes[i].begin(), col_shapes[i].end(), vec.begin());
t_shape = TensorShape(vec);
col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, t_shape.Rank(), &t_shape);
} else {  // unknown shape
@@ -162,30 +163,7 @@ Status MindRecordOp::Init() {
num_rows_ = shard_reader_->GetNumRows();
// Compute how many buffers we would need to accomplish rowsPerBuffer
buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_;
RETURN_IF_NOT_OK(SetColumnsBlob());
return Status::OK();
}
Status MindRecordOp::SetColumnsBlob() {
columns_blob_ = shard_reader_->GetBlobFields().second;
// get the exactly blob fields by columns_to_load_
std::vector<std::string> columns_blob_exact;
for (auto &blob_field : columns_blob_) {
for (auto &column : columns_to_load_) {
if (column.compare(blob_field) == 0) {
columns_blob_exact.push_back(blob_field);
break;
}
}
}
columns_blob_index_ = std::vector<int32_t>(columns_to_load_.size(), -1);
int32_t iBlob = 0;
for (auto &blob_exact : columns_blob_exact) {
columns_blob_index_[column_name_id_map_[blob_exact]] = iBlob++;
}
return Status::OK();
}
@@ -215,248 +193,18 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
}
}
template <typename T>
Status MindRecordOp::LoadFeature(std::shared_ptr<Tensor> *tensor, int32_t i_col,
const std::vector<uint8_t> &columns_blob, const mindrecord::json &columns_json) const {
TensorShape new_shape = TensorShape::CreateUnknownRankShape();
const unsigned char *data = nullptr;
std::unique_ptr<T[]> array_data;
std::string string_data;
const ColDescriptor &cur_column = data_schema_->column(i_col);
std::string column_name = columns_to_load_[i_col];
DataType type = cur_column.type();
// load blob column
if (columns_blob_index_[i_col] >= 0 && columns_blob.size() > 0) {
int32_t pos = columns_blob_.size() == 1 ? -1 : columns_blob_index_[i_col];
RETURN_IF_NOT_OK(LoadBlob(&new_shape, &data, columns_blob, pos, cur_column));
} else {
switch (type.value()) {
case DataType::DE_UINT8: {
// For strings (Assume DE_UINT8 is reserved for strings)
RETURN_IF_NOT_OK(LoadByte(&new_shape, &string_data, column_name, columns_json));
data = reinterpret_cast<const unsigned char *>(common::SafeCStr(string_data));
break;
}
case DataType::DE_FLOAT32: {
// For both float scalars and arrays
RETURN_IF_NOT_OK(LoadFloat(&new_shape, &array_data, column_name, columns_json, cur_column, false));
data = reinterpret_cast<const unsigned char *>(array_data.get());
break;
}
case DataType::DE_FLOAT64: {
// For both double scalars and arrays
RETURN_IF_NOT_OK(LoadFloat(&new_shape, &array_data, column_name, columns_json, cur_column, true));
data = reinterpret_cast<const unsigned char *>(array_data.get());
break;
}
default: {
// For both integers scalars and arrays
RETURN_IF_NOT_OK(LoadInt(&new_shape, &array_data, column_name, columns_json, cur_column));
data = reinterpret_cast<const unsigned char *>(array_data.get());
break;
}
}
}
// Create Tensor with given details
RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, cur_column.tensorImpl(), new_shape, type, data));
return Status::OK();
}
Status MindRecordOp::LoadBlob(TensorShape *new_shape, const unsigned char **data,
const std::vector<uint8_t> &columns_blob, const int32_t pos,
const ColDescriptor &column) {
const auto kColumnSize = column.type().SizeInBytes();
if (kColumnSize == 0) {
RETURN_STATUS_UNEXPECTED("column size is null");
}
if (pos == -1) {
if (column.hasShape()) {
*new_shape = TensorShape::CreateUnknownRankShape();
RETURN_IF_NOT_OK(
column.MaterializeTensorShape(static_cast<int32_t>(columns_blob.size() / kColumnSize), new_shape));
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(columns_blob.size() / kColumnSize)};
*new_shape = TensorShape(shapeDetails);
}
*data = reinterpret_cast<const uint8_t *>(&(columns_blob[0]));
return Status::OK();
}
auto uint64_from_bytes = [&](int64_t pos) {
uint64_t result = 0;
for (uint64_t n = 0; n < kInt64Len; n++) {
result = (result << 8) + columns_blob[pos + n];
}
return result;
};
uint64_t iStart = 0;
for (int32_t i = 0; i < pos; i++) {
uint64_t num_bytes = uint64_from_bytes(iStart);
iStart += kInt64Len + num_bytes;
}
uint64_t num_bytes = uint64_from_bytes(iStart);
iStart += kInt64Len;
if (column.hasShape()) {
*new_shape = TensorShape::CreateUnknownRankShape();
RETURN_IF_NOT_OK(column.MaterializeTensorShape(static_cast<int32_t>(num_bytes / kColumnSize), new_shape));
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(num_bytes / kColumnSize)};
*new_shape = TensorShape(shapeDetails);
}
*data = reinterpret_cast<const uint8_t *>(&(columns_blob[iStart]));
return Status::OK();
}
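// --- Editor's note (illustration, not part of the change) ---------------------
// The removed LoadBlob above walks a blob laid out as consecutive
// [8-byte big-endian length][payload] records and skips `pos` records to reach
// the requested column. A minimal standalone sketch of parsing that layout,
// assuming the same 8-byte (kInt64Len) length prefix; the helper name
// ParseBlobRecords is hypothetical.
#include <cstddef>
#include <cstdint>
#include <vector>
std::vector<std::vector<uint8_t>> ParseBlobRecords(const std::vector<uint8_t> &blob) {
  constexpr std::size_t kLenBytes = 8;  // matches kInt64Len
  std::vector<std::vector<uint8_t>> records;
  std::size_t offset = 0;
  while (offset + kLenBytes <= blob.size()) {
    uint64_t len = 0;
    for (std::size_t n = 0; n < kLenBytes; ++n) {
      len = (len << 8) + blob[offset + n];  // big-endian length prefix
    }
    offset += kLenBytes;
    if (offset + len > blob.size()) break;  // truncated record, stop parsing
    records.emplace_back(blob.begin() + offset, blob.begin() + offset + len);
    offset += len;
  }
  return records;
}
// -------------------------------------------------------------------------------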
template <typename T>
Status MindRecordOp::LoadFloat(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column, bool use_double) {
if (!columns_json[column_name].is_array()) {
T value = 0;
RETURN_IF_NOT_OK(GetFloat(&value, columns_json[column_name], use_double));
*new_shape = TensorShape::CreateScalar();
*array_data = std::make_unique<T[]>(1);
(*array_data)[0] = value;
} else {
if (column.hasShape()) {
*new_shape = TensorShape(column.shape());
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(columns_json[column_name].size())};
*new_shape = TensorShape(shapeDetails);
}
int idx = 0;
*array_data = std::make_unique<T[]>(new_shape->NumOfElements());
for (auto &element : columns_json[column_name]) {
T value = 0;
RETURN_IF_NOT_OK(GetFloat(&value, element, use_double));
(*array_data)[idx++] = value;
}
}
return Status::OK();
}
template <typename T>
Status MindRecordOp::GetFloat(T *value, const mindrecord::json &data, bool use_double) {
if (data.is_number()) {
*value = data;
} else if (data.is_string()) {
try {
if (use_double) {
*value = data.get<double>();
} else {
*value = data.get<float>();
}
} catch (mindrecord::json::exception &e) {
RETURN_STATUS_UNEXPECTED("Conversion to float failed.");
}
} else {
RETURN_STATUS_UNEXPECTED("Conversion to float failed.");
}
return Status::OK();
}
template <typename T>
Status MindRecordOp::LoadInt(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column) {
if (!columns_json[column_name].is_array()) {
T value = 0;
RETURN_IF_NOT_OK(GetInt(&value, columns_json[column_name]));
*new_shape = TensorShape::CreateScalar();
*array_data = std::make_unique<T[]>(1);
(*array_data)[0] = value;
} else {
if (column.hasShape()) {
*new_shape = TensorShape(column.shape());
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(columns_json[column_name].size())};
*new_shape = TensorShape(shapeDetails);
}
int idx = 0;
*array_data = std::make_unique<T[]>(new_shape->NumOfElements());
for (auto &element : columns_json[column_name]) {
T value = 0;
RETURN_IF_NOT_OK(GetInt(&value, element));
(*array_data)[idx++] = value;
}
}
return Status::OK();
}
template <typename T>
Status MindRecordOp::GetInt(T *value, const mindrecord::json &data) {
int64_t temp_value = 0;
bool less_than_zero = false;
if (data.is_number_integer()) {
const mindrecord::json json_zero = 0;
if (data < json_zero) less_than_zero = true;
temp_value = data;
} else if (data.is_string()) {
std::string string_value = data;
if (!string_value.empty() && string_value[0] == '-') {
try {
temp_value = std::stoll(string_value);
less_than_zero = true;
} catch (std::invalid_argument &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed, invalid argument.");
} catch (std::out_of_range &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range.");
}
} else {
try {
temp_value = static_cast<int64_t>(std::stoull(string_value));
} catch (std::invalid_argument &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed, invalid argument.");
} catch (std::out_of_range &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range.");
}
}
} else {
RETURN_STATUS_UNEXPECTED("Conversion to int failed.");
}
if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) ||
(!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed. Out of range");
}
*value = static_cast<T>(temp_value);
return Status::OK();
}
Status MindRecordOp::LoadByte(TensorShape *new_shape, std::string *string_data, const std::string &column_name,
const mindrecord::json &columns_json) {
*string_data = columns_json[column_name];
std::vector<dsize_t> shape_details = {static_cast<dsize_t>(string_data->size())};
*new_shape = TensorShape(shape_details);
return Status::OK();
}
Status MindRecordOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
std::unique_ptr<IOBlock> io_block;
RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block));
while (io_block != nullptr) {
if (io_block->eoe() == true) {
if (io_block->eoe()) {
RETURN_IF_NOT_OK(
out_connector_->Add(worker_id, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block));
continue;
}
if (io_block->eof() == true) {
if (io_block->eof()) {
RETURN_IF_NOT_OK(
out_connector_->Add(worker_id, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block));
@@ -521,19 +269,10 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_bu
if (tupled_buffer.empty()) break;
}
for (const auto &tupled_row : tupled_buffer) {
std::vector<uint8_t> columnsBlob = std::get<0>(tupled_row);
std::vector<uint8_t> columns_blob = std::get<0>(tupled_row);
mindrecord::json columns_json = std::get<1>(tupled_row);
TensorRow tensor_row;
for (uint32_t j = 0; j < columns_to_load_.size(); ++j) {
std::shared_ptr<Tensor> tensor;
const ColDescriptor &cur_column = data_schema_->column(j);
DataType type = cur_column.type();
RETURN_IF_NOT_OK(SwitchLoadFeature(type, &tensor, j, columnsBlob, columns_json));
tensor_row.push_back(std::move(tensor));
}
RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, columns_blob, columns_json));
tensor_table->push_back(std::move(tensor_row));
}
}
@@ -543,48 +282,46 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_bu
return Status::OK();
}
Status MindRecordOp::SwitchLoadFeature(const DataType &type, std::shared_ptr<Tensor> *tensor, int32_t i_col,
const std::vector<uint8_t> &columns_blob,
const mindrecord::json &columns_json) const {
switch (type.value()) {
case DataType::DE_BOOL: {
return LoadFeature<bool>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_INT8: {
return LoadFeature<int8_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_UINT8: {
return LoadFeature<uint8_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_INT16: {
return LoadFeature<int16_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_UINT16: {
return LoadFeature<uint16_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_INT32: {
return LoadFeature<int32_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_UINT32: {
return LoadFeature<uint32_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_INT64: {
return LoadFeature<int64_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_UINT64: {
return LoadFeature<uint64_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_FLOAT32: {
return LoadFeature<float>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_FLOAT64: {
return LoadFeature<double>(tensor, i_col, columns_blob, columns_json);
Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint8_t> &columns_blob,
const mindrecord::json &columns_json) {
for (uint32_t i_col = 0; i_col < columns_to_load_.size(); i_col++) {
auto column_name = columns_to_load_[i_col];
// Initialize column parameters
const unsigned char *data = nullptr;
std::unique_ptr<unsigned char[]> data_ptr;
uint64_t n_bytes = 0;
mindrecord::ColumnDataType column_data_type = mindrecord::ColumnNoDataType;
uint64_t column_data_type_size = 1;
std::vector<int64_t> column_shape;
// Get column data
auto has_column = shard_reader_->get_shard_column()->GetColumnValueByName(
column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, &column_data_type, &column_data_type_size,
&column_shape);
if (has_column == MSRStatus::FAILED) {
RETURN_STATUS_UNEXPECTED("Failed to retrieve data from mindrecord reader.");
}
default: {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"mindrecord column list type does not match any known types");
std::shared_ptr<Tensor> tensor;
const ColDescriptor &column = data_schema_->column(i_col);
DataType type = column.type();
// Set shape
auto num_elements = n_bytes / column_data_type_size;
if (column.hasShape()) {
auto new_shape = TensorShape(column.shape());
RETURN_IF_NOT_OK(column.MaterializeTensorShape(static_cast<int32_t>(num_elements), &new_shape));
RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data));
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(num_elements)};
auto new_shape = TensorShape(shapeDetails);
RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data));
}
tensor_row->push_back(std::move(tensor));
}
return Status::OK();
}
Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) {
@@ -23,6 +23,7 @@
#include <queue>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
@@ -31,6 +32,7 @@
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/util/queue.h"
#include "dataset/util/status.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_reader.h"
#include "mindrecord/include/common/shard_utils.h"
@@ -193,8 +195,6 @@ class MindRecordOp : public ParallelOp {
Status Init();
Status SetColumnsBlob();
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
@@ -205,56 +205,11 @@ class MindRecordOp : public ParallelOp {
Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id);
// Parses a single cell and puts the data into a tensor
// @param tensor - the tensor to put the parsed data in
// @param i_col - the id of column to parse
// @param tensor_row - the tensor row to put the parsed data in
// @param columns_blob - the blob data received from the reader
// @param columns_json - the data for fields received from the reader
template <typename T>
Status LoadFeature(std::shared_ptr<Tensor> *tensor, int32_t i_col, const std::vector<uint8_t> &columns_blob,
const mindrecord::json &columns_json) const;
Status SwitchLoadFeature(const DataType &type, std::shared_ptr<Tensor> *tensor, int32_t i_col,
const std::vector<uint8_t> &columns_blob, const mindrecord::json &columns_json) const;
static Status LoadBlob(TensorShape *new_shape, const unsigned char **data, const std::vector<uint8_t> &columns_blob,
const int32_t pos, const ColDescriptor &column);
// Get shape and data (scalar or array) for tensor to be created (for floats and doubles)
// @param new_shape - the shape of tensor to be created.
// @param array_data - the array where data should be put in
// @param column_name - name of current column to be processed
// @param columns_json - the data for fields received from the reader
// @param column - description of current column from schema
// @param use_double - boolean to choose between float32 and float64
template <typename T>
static Status LoadFloat(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column, bool use_double);
// Get shape and data (scalar or array) for tensor to be created (for integers)
// @param new_shape - the shape of tensor to be created.
// @param array_data - the array where data should be put in
// @param column_name - name of current column to be processed
// @param columns_json - the data for fields received from the reader
// @param column - description of current column from schema
template <typename T>
static Status LoadInt(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column);
static Status LoadByte(TensorShape *new_shape, std::string *string_data, const std::string &column_name,
const mindrecord::json &columns_json);
// Get a single float value from the given json
// @param value - the float to put the value in
// @param arrayData - the given json containing the float
// @param use_double - boolean to choose between float32 and float64
template <typename T>
static Status GetFloat(T *value, const mindrecord::json &data, bool use_double);
// Get a single integer value from the given json
// @param value - the integer to put the value in
// @param arrayData - the given json containing the integer
template <typename T>
static Status GetInt(T *value, const mindrecord::json &data);
Status LoadTensorRow(TensorRow *tensor_row, const std::vector<uint8_t> &columns_blob,
const mindrecord::json &columns_json);
Status FetchBlockBuffer(const int32_t &buffer_id);
@@ -91,8 +91,8 @@ void BindShardReader(const py::module *m) {
.def("launch", &ShardReader::Launch)
.def("get_header", &ShardReader::GetShardHeader)
.def("get_blob_fields", &ShardReader::GetBlobFields)
.def("get_next",
(std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>(ShardReader::*)()) & ShardReader::GetNextPy)
.def("get_next", (std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>(ShardReader::*)()) &
ShardReader::GetNextPy)
.def("finish", &ShardReader::Finish)
.def("close", &ShardReader::Close);
}
@@ -65,6 +65,9 @@ const int kUnsignedInt4 = 4;
enum LabelCategory { kSchemaLabel, kStatisticsLabel, kIndexLabel };
const char kVersion[] = "3.0";
const std::vector<std::string> kSupportedVersion = {"2.0", kVersion};
enum ShardType {
kNLP = 0,
kCV = 1,
@@ -0,0 +1,163 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDRECORD_INCLUDE_SHARD_COLUMN_H_
#define MINDRECORD_INCLUDE_SHARD_COLUMN_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "mindrecord/include/shard_header.h"
namespace mindspore {
namespace mindrecord {
const uint64_t kUnsignedOne = 1;
const uint64_t kBitsOfByte = 8;
const uint64_t kDataTypeBits = 2;
const uint64_t kNumDataOfByte = 4;
const uint64_t kBytesOfColumnLen = 4;
const uint64_t kDataTypeBitMask = 3;
const uint64_t kDataTypes = 6;
enum IntegerType { kInt8Type = 0, kInt16Type, kInt32Type, kInt64Type };
enum ColumnCategory { ColumnInRaw, ColumnInBlob, ColumnNotFound };
enum ColumnDataType {
ColumnBytes = 0,
ColumnString = 1,
ColumnInt32 = 2,
ColumnInt64 = 3,
ColumnFloat32 = 4,
ColumnFloat64 = 5,
ColumnNoDataType = 6
};
// mapping as {"bytes", "string", "int32", "int64", "float32", "float64"};
const uint32_t ColumnDataTypeSize[kDataTypes] = {1, 1, 4, 8, 4, 8};
const std::vector<std::string> ColumnDataTypeNameNormalized = {"uint8", "uint8", "int32",
"int64", "float32", "float64"};
const std::unordered_map<std::string, ColumnDataType> ColumnDataTypeMap = {
{"bytes", ColumnBytes}, {"string", ColumnString}, {"int32", ColumnInt32},
{"int64", ColumnInt64}, {"float32", ColumnFloat32}, {"float64", ColumnFloat64}};
class ShardColumn {
 public:
explicit ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer = true);
~ShardColumn() = default;
/// \brief get column value by column name
MSRStatus GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const json &columns_json, const unsigned char **data,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes,
ColumnDataType *column_data_type, uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape);
/// \brief compress blob
std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob);
/// \brief check if blob compressed
bool CheckCompressBlob() const { return has_compress_blob_; }
uint64_t GetNumBlobColumn() const { return num_blob_column_; }
std::vector<std::string> GetColumnName() { return column_name_; }
std::vector<ColumnDataType> GeColumnDataType() { return column_data_type_; }
std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; }
/// \brief get column value from blob
MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
uint64_t *n_bytes);
 private:
/// \brief get column value from json
MSRStatus GetColumnFromJson(const std::string &column_name, const json &columns_json,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes);
/// \brief get float value from json
template <typename T>
MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double);
/// \brief get integer value from json
template <typename T>
MSRStatus GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value);
/// \brief get column offset address and size from blob
MSRStatus GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
uint64_t *num_bytes, uint64_t *shift_idx);
/// \brief check if column name is available
ColumnCategory CheckColumnName(const std::string &column_name);
/// \brief compress integer column
static vector<uint8_t> CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type);
/// \brief uncompress integer array column
template <typename T>
static MSRStatus UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *data_ptr,
const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx);
/// \brief convert big-endian bytes to unsigned int
/// \param bytes_array bytes array
/// \param pos shift address in bytes array
/// \param i_type integer type
/// \return unsigned int
static uint64_t BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
const IntegerType &i_type);
/// \brief convert unsigned int to big-endian bytes
/// \param value integer value
/// \param i_type integer type
/// \return bytes
static std::vector<uint8_t> UIntToBytesBig(uint64_t value, const IntegerType &i_type);
/// \brief convert unsigned int to little-endian bytes
/// \param value integer value
/// \param i_type integer type
/// \return bytes
static std::vector<uint8_t> UIntToBytesLittle(uint64_t value, const IntegerType &i_type);
/// \brief convert little-endian bytes to the smallest fitting integer type
/// \param bytes_array bytes array
/// \param pos shift address in bytes array
/// \param src_i_type source integer type
/// \param dst_i_type (output), destination integer type
/// \return integer
static int64_t BytesLittleToMinIntType(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
const IntegerType &src_i_type, IntegerType *dst_i_type = nullptr);
 private:
std::vector<std::string> column_name_;                 // column name list
std::vector<ColumnDataType> column_data_type_;         // column data type list
std::vector<std::vector<int64_t>> column_shape_;       // column shape list
std::unordered_map<string, uint64_t> column_name_id_;  // column name id map
std::vector<std::string> blob_column_;                 // blob column list
std::unordered_map<std::string, uint64_t> blob_column_id_;  // blob column name id map
bool has_compress_blob_;                               // if has compress blob
uint64_t num_blob_column_;                             // number of blob columns
};
}  // namespace mindrecord
}  // namespace mindspore
#endif  // MINDRECORD_INCLUDE_SHARD_COLUMN_H_
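// --- Editor's note (illustration, not part of the new header) -----------------
// BytesBigToUInt64 / UIntToBytesBig above are declared to convert between
// unsigned integers and big-endian byte sequences whose width is implied by
// IntegerType (1, 2, 4 or 8 bytes). A minimal sketch of that kind of
// conversion, with hypothetical names, assuming n_bytes in {1, 2, 4, 8}:
#include <cstdint>
#include <vector>
uint64_t BigEndianBytesToUInt64(const std::vector<uint8_t> &bytes, uint64_t pos, uint64_t n_bytes) {
  uint64_t value = 0;
  for (uint64_t i = 0; i < n_bytes; ++i) {
    value = (value << 8) | bytes[pos + i];  // most significant byte first
  }
  return value;
}
std::vector<uint8_t> UInt64ToBigEndianBytes(uint64_t value, uint64_t n_bytes) {
  std::vector<uint8_t> out(n_bytes, 0);
  for (uint64_t i = 0; i < n_bytes; ++i) {
    out[n_bytes - 1 - i] = static_cast<uint8_t>(value & 0xFF);  // fill from the least significant end
    value >>= 8;
  }
  return out;
}
// -------------------------------------------------------------------------------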
@@ -118,8 +118,6 @@ class ShardHeader {
void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }
const string GetVersion() { return version_; }
std::vector<std::string> SerializeHeader();
MSRStatus PagesToFile(const std::string dump_file_name);
@@ -175,7 +173,6 @@ class ShardHeader {
uint32_t shard_count_;
uint64_t header_size_;
uint64_t page_size_;
string version_ = "2.0";
std::shared_ptr<Index> index_;
std::vector<std::string> shard_addresses_;
@@ -43,6 +43,7 @@
#include <vector>
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_index_generator.h"
#include "mindrecord/include/shard_operator.h"
@@ -111,6 +112,10 @@ class ShardReader {
/// \return the metadata
std::shared_ptr<ShardHeader> GetShardHeader() const;
/// \brief aim to get columns context
/// \return the columns
std::shared_ptr<ShardColumn> get_shard_column() const;
/// \brief get the number of shards
/// \return # of shards
int GetShardCount() const;
@@ -185,7 +190,7 @@ class ShardReader {
/// \brief return a batch, given that one is ready, python API
/// \return a batch of images and image data
std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>> GetNextPy();
std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> GetNextPy();
/// \brief get blob filed list
/// \return blob field list
@@ -295,16 +300,18 @@ class ShardReader {
/// \brief get number of classes
int64_t GetNumClasses(const std::string &category_field);
/// \brief get meta of header
std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path, json &meta_data);
/// \brief get exactly blob fields data by indices
std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes,
std::vector<uint32_t> &ordered_selected_columns_index);
/// \brief extract uncompressed data based on column list
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> UnCompressBlob(const std::vector<uint8_t> &raw_blob_data);
 protected:
uint64_t header_size_;  // header size
uint64_t page_size_;    // page size
int shard_count_;       // number of shards
std::shared_ptr<ShardHeader> shard_header_;  // shard header
std::shared_ptr<ShardColumn> shard_column_;  // shard column
std::vector<sqlite3 *> database_paths_;  // sqlite handle list
std::vector<string> file_paths_;         // file paths
@@ -36,6 +36,7 @@
#include <utility>
#include <vector>
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_header.h"
#include "mindrecord/include/shard_index.h"
@@ -242,7 +243,8 @@ class ShardWriter {
std::vector<std::string> file_paths_;                      // file paths
std::vector<std::shared_ptr<std::fstream>> file_streams_;  // file handles
std::shared_ptr<ShardHeader> shard_header_;  // shard headers
std::shared_ptr<ShardHeader> shard_header_;  // shard header
std::shared_ptr<ShardColumn> shard_column_;  // shard columns
std::map<uint64_t, std::map<int, std::string>> err_mg_;  // used for storing error raw_data info
@@ -133,6 +133,12 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
shard_header_ = std::make_shared<ShardHeader>(sh);
header_size_ = shard_header_->GetHeaderSize();
page_size_ = shard_header_->GetPageSize();
// version < 3.0
if (first_meta_data["version"] < kVersion) {
shard_column_ = std::make_shared<ShardColumn>(shard_header_, false);
} else {
shard_column_ = std::make_shared<ShardColumn>(shard_header_, true);
}
num_rows_ = 0;
auto row_group_summary = ReadRowGroupSummary();
for (const auto &rg : row_group_summary) {
@@ -226,6 +232,8 @@ void ShardReader::Close() {
std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; }
std::shared_ptr<ShardColumn> ShardReader::get_shard_column() const { return shard_column_; }
int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); }
int ShardReader::GetNumRows() const { return num_rows_; }
@@ -1059,36 +1067,6 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
return SUCCESS;
}
std::vector<uint8_t> ShardReader::ExtractBlobFieldBySelectColumns(
std::vector<uint8_t> &blob_fields_bytes, std::vector<uint32_t> &ordered_selected_columns_index) {
std::vector<uint8_t> exactly_blob_fields_bytes;
auto uint64_from_bytes = [&](int64_t pos) {
uint64_t result = 0;
for (uint64_t n = 0; n < kInt64Len; n++) {
result = (result << 8) + blob_fields_bytes[pos + n];
}
return result;
};
// get the exactly blob fields
uint32_t current_index = 0;
uint64_t current_offset = 0;
uint64_t data_len = uint64_from_bytes(current_offset);
while (current_offset < blob_fields_bytes.size()) {
if (std::any_of(ordered_selected_columns_index.begin(), ordered_selected_columns_index.end(),
[&current_index](uint32_t &index) { return index == current_index; })) {
exactly_blob_fields_bytes.insert(exactly_blob_fields_bytes.end(), blob_fields_bytes.begin() + current_offset,
blob_fields_bytes.begin() + current_offset + kInt64Len + data_len);
}
current_index++;
current_offset += kInt64Len + data_len;
data_len = uint64_from_bytes(current_offset);
}
return exactly_blob_fields_bytes;
}
TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) {
// All tasks are done
if (task_id >= static_cast<int>(tasks_.Size())) {
@@ -1126,40 +1104,10 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
return std::make_pair(FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>());
}
// extract the exactly blob bytes by selected columns
std::vector<uint8_t> images_with_exact_columns;
if (selected_columns_.size() == 0) {
images_with_exact_columns = images;
} else {
auto blob_fields = GetBlobFields();
std::vector<uint32_t> ordered_selected_columns_index;
uint32_t index = 0;
for (auto &blob_field : blob_fields.second) {
for (auto &field : selected_columns_) {
if (field.compare(blob_field) == 0) {
ordered_selected_columns_index.push_back(index);
break;
}
}
index++;
}
if (ordered_selected_columns_index.size() != 0) {
// extract the images
if (blob_fields.second.size() == 1) {
if (ordered_selected_columns_index.size() == 1) {
images_with_exact_columns = images;
}
} else {
images_with_exact_columns = ExtractBlobFieldBySelectColumns(images, ordered_selected_columns_index);
}
}
}
// Deliver batch data to output map
std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
batch.emplace_back(std::move(images_with_exact_columns), std::move(std::get<2>(task)));
batch.emplace_back(std::move(images), std::move(std::get<2>(task)));
return std::make_pair(SUCCESS, std::move(batch));
}
@@ -1369,16 +1317,41 @@ std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNextById(con
return std::move(ret.second);
}
std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>> ShardReader::GetNextPy() {
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardReader::UnCompressBlob(
const std::vector<uint8_t> &raw_blob_data) {
auto loaded_columns = selected_columns_.size() == 0 ? shard_column_->GetColumnName() : selected_columns_;
auto blob_fields = GetBlobFields().second;
std::vector<std::vector<uint8_t>> blob_data;
for (uint32_t i_col = 0; i_col < loaded_columns.size(); ++i_col) {
if (std::find(blob_fields.begin(), blob_fields.end(), loaded_columns[i_col]) == blob_fields.end()) continue;
const unsigned char *data = nullptr;
std::unique_ptr<unsigned char[]> data_ptr;
uint64_t n_bytes = 0;
auto ret = shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes);
if (ret != SUCCESS) {
MS_LOG(ERROR) << "Error when get data from blob, column name is " << loaded_columns[i_col] << ".";
return {FAILED, std::vector<std::vector<uint8_t>>(blob_fields.size(), std::vector<uint8_t>())};
}
if (data == nullptr) {
data = reinterpret_cast<const unsigned char *>(data_ptr.get());
}
std::vector<uint8_t> column(data, data + (n_bytes / sizeof(unsigned char)));
blob_data.push_back(column);
}
return {SUCCESS, blob_data};
}
std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> ShardReader::GetNextPy() {
auto res = GetNext();
vector<std::tuple<std::vector<uint8_t>, pybind11::object>> jsonData;
std::transform(res.begin(), res.end(), std::back_inserter(jsonData),
[](const std::tuple<std::vector<uint8_t>, json> &item) {
vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> data;
std::transform(res.begin(), res.end(), std::back_inserter(data),
[this](const std::tuple<std::vector<uint8_t>, json> &item) {
auto &j = std::get<1>(item);
pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
return std::make_tuple(std::get<0>(item), std::move(obj));
auto ret = UnCompressBlob(std::get<0>(item));
return std::make_tuple(ret.second, std::move(obj));
});
return jsonData;
return data;
}
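// Editor's note: after this change GetNextPy returns, for each row, one byte
// buffer per selected blob column (produced by UnCompressBlob above) instead of
// a single concatenated blob, matching the updated pybind11 signature in
// BindShardReader.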
void ShardReader::Reset() {
@@ -206,6 +206,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
MS_LOG(ERROR) << "Open file failed";
return FAILED;
}
shard_column_ = std::make_shared<ShardColumn>(shard_header_);
return SUCCESS;
}
@@ -271,6 +272,7 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data)
shard_header_ = header_data;
shard_header_->SetHeaderSize(header_size_);
shard_header_->SetPageSize(page_size_);
shard_column_ = std::make_shared<ShardColumn>(shard_header_);
return SUCCESS;
}
@@ -608,6 +610,14 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
MS_LOG(ERROR) << "IO error / there is no free disk to be used";
return FAILED;
}
// compress blob
if (shard_column_->CheckCompressBlob()) {
for (auto &blob : blob_data) {
blob = shard_column_->CompressBlob(blob);
}
}
// Add 4-bytes dummy blob data if no any blob fields
if (blob_data.size() == 0 && raw_data.size() > 0) {
blob_data = std::vector<std::vector<uint8_t>>(raw_data[0].size(), std::vector<uint8_t>(kUnsignedInt4, 0));
@@ -0,0 +1,473 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "mindrecord/include/shard_column.h"
#include "common/utils.h"
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/shard_error.h"
namespace mindspore {
namespace mindrecord {
ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer) {
auto first_schema = shard_header->GetSchemas()[0];
auto schema = first_schema->GetSchema()["schema"];
bool has_integer_array = false;
for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
const std::string &column_name = it.key();
column_name_.push_back(column_name);
json it_value = it.value();
std::string str_type = it_value["type"];
column_data_type_.push_back(ColumnDataTypeMap.at(str_type));
if (it_value.find("shape") != it_value.end()) {
std::vector<int64_t> vec(it_value["shape"].size());
std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin());
column_shape_.push_back(vec);
if (str_type == "int32" || str_type == "int64") {
has_integer_array = true;
}
} else {
std::vector<int64_t> vec = {};
column_shape_.push_back(vec);
}
}
for (uint64_t i = 0; i < column_name_.size(); i++) {
column_name_id_[column_name_[i]] = i;
}
auto blob_fields = first_schema->GetBlobFields();
for (const auto &field : blob_fields) {
blob_column_.push_back(field);
}
for (uint64_t i = 0; i < blob_column_.size(); i++) {
blob_column_id_[blob_column_[i]] = i;
}
has_compress_blob_ = (compress_integer && has_integer_array);
num_blob_column_ = blob_column_.size();
}
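// --- Editor's note (hypothetical example, not from the PR) --------------------
// For a schema such as
//   {"image": {"type": "bytes"}, "label": {"type": "int64", "shape": [1]}}
// the constructor above would yield column_name_ = {"image", "label"},
// column_data_type_ = {ColumnBytes, ColumnInt64}, column_shape_ = {{}, {1}},
// and, because an integer column carries a shape, has_compress_blob_ equal to
// the compress_integer argument.
// -------------------------------------------------------------------------------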
MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const json &columns_json, const unsigned char **data,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes,
ColumnDataType *column_data_type, uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape) {
// Skip if column not found
auto column_category = CheckColumnName(column_name);
if (column_category == ColumnNotFound) {
return FAILED;
}
// Get data type and size
auto column_id = column_name_id_[column_name];
*column_data_type = column_data_type_[column_id];
*column_data_type_size = ColumnDataTypeSize[*column_data_type];
*column_shape = column_shape_[column_id];
// Retrieve value from json
if (column_category == ColumnInRaw) {
if (GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes) == FAILED) {
MS_LOG(ERROR) << "Error when get data from json, column name is " << column_name << ".";
return FAILED;
}
*data = reinterpret_cast<const unsigned char *>(data_ptr->get());
return SUCCESS;
}
// Retrieve value from blob
if (GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes) == FAILED) {
MS_LOG(ERROR) << "Error when get data from blob, column name is " << column_name << ".";
return FAILED;
}
if (*data == nullptr) {
*data = reinterpret_cast<const unsigned char *>(data_ptr->get());
}
return SUCCESS;
}
| MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json, | |||||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes) { | |||||
| auto column_id = column_name_id_[column_name]; | |||||
| auto column_data_type = column_data_type_[column_id]; | |||||
| // Initialize num bytes | |||||
| *n_bytes = ColumnDataTypeSize[column_data_type]; | |||||
| auto json_column_value = columns_json[column_name]; | |||||
| switch (column_data_type) { | |||||
| case ColumnFloat32: { | |||||
| return GetFloat<float>(data_ptr, json_column_value, false); | |||||
| } | |||||
| case ColumnFloat64: { | |||||
| return GetFloat<double>(data_ptr, json_column_value, true); | |||||
| } | |||||
| case ColumnInt32: { | |||||
| return GetInt<int32_t>(data_ptr, json_column_value); | |||||
| } | |||||
| case ColumnInt64: { | |||||
| return GetInt<int64_t>(data_ptr, json_column_value); | |||||
| } | |||||
| default: { | |||||
| // Copy string bytes into the output buffer | |||||
| std::string tmp_string = json_column_value; | |||||
| *n_bytes = tmp_string.size(); | |||||
| auto data = reinterpret_cast<const unsigned char *>(common::SafeCStr(tmp_string)); | |||||
| *data_ptr = std::make_unique<unsigned char[]>(*n_bytes); | |||||
| for (uint32_t i = 0; i < *n_bytes; i++) { | |||||
| (*data_ptr)[i] = *(data + i); | |||||
| } | |||||
| break; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| template <typename T> | |||||
| MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, | |||||
| bool use_double) { | |||||
| std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1); | |||||
| if (!json_column_value.is_string() && !json_column_value.is_number()) { | |||||
| MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ")."; | |||||
| return FAILED; | |||||
| } | |||||
| if (json_column_value.is_number()) { | |||||
| array_data[0] = json_column_value; | |||||
| } else { | |||||
| // Convert string to float | |||||
| try { | |||||
| if (use_double) { | |||||
| array_data[0] = json_column_value.get<double>(); | |||||
| } else { | |||||
| array_data[0] = json_column_value.get<float>(); | |||||
| } | |||||
| } catch (json::exception &e) { | |||||
| MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ")."; | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| auto data = reinterpret_cast<const unsigned char *>(array_data.get()); | |||||
| *data_ptr = std::make_unique<unsigned char[]>(sizeof(T)); | |||||
| for (uint32_t i = 0; i < sizeof(T); i++) { | |||||
| (*data_ptr)[i] = *(data + i); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| template <typename T> | |||||
| MSRStatus ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value) { | |||||
| std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1); | |||||
| int64_t temp_value; | |||||
| bool less_than_zero = false; | |||||
| if (json_column_value.is_number_integer()) { | |||||
| const json json_zero = 0; | |||||
| if (json_column_value < json_zero) less_than_zero = true; | |||||
| temp_value = json_column_value; | |||||
| } else if (json_column_value.is_string()) { | |||||
| std::string string_value = json_column_value; | |||||
| if (!string_value.empty() && string_value[0] == '-') { | |||||
| try { | |||||
| temp_value = std::stoll(string_value); | |||||
| less_than_zero = true; | |||||
| } catch (std::invalid_argument &e) { | |||||
| MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; | |||||
| return FAILED; | |||||
| } catch (std::out_of_range &e) { | |||||
| MS_LOG(ERROR) << "Conversion to int failed, out of range."; | |||||
| return FAILED; | |||||
| } | |||||
| } else { | |||||
| try { | |||||
| temp_value = static_cast<int64_t>(std::stoull(string_value)); | |||||
| } catch (std::invalid_argument &e) { | |||||
| MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; | |||||
| return FAILED; | |||||
| } catch (std::out_of_range &e) { | |||||
| MS_LOG(ERROR) << "Conversion to int failed, out of range."; | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Conversion to int failed."; | |||||
| return FAILED; | |||||
| } | |||||
| if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) || | |||||
| (!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) { | |||||
| MS_LOG(ERROR) << "Conversion to int failed. Out of range"; | |||||
| return FAILED; | |||||
| } | |||||
| array_data[0] = static_cast<T>(temp_value); | |||||
| auto data = reinterpret_cast<const unsigned char *>(array_data.get()); | |||||
| *data_ptr = std::make_unique<unsigned char[]>(sizeof(T)); | |||||
| for (uint32_t i = 0; i < sizeof(T); i++) { | |||||
| (*data_ptr)[i] = *(data + i); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
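GetFloat and GetInt above narrow a JSON scalar (or a decimal string) into a fixed-width value and fail when it does not fit. A rough, illustrative Python equivalent of the integer path (not the production code; the helper name is made up):

    import numpy as np

    def json_value_to_int(value, dtype):
        """Narrow an int or a decimal string to dtype, rejecting out-of-range values."""
        temp = int(value)                  # raises ValueError for non-numeric strings
        info = np.iinfo(dtype)
        if temp < info.min or temp > info.max:
            raise OverflowError("{} does not fit in {}".format(temp, np.dtype(dtype).name))
        return dtype(temp)

    assert json_value_to_int("-129", np.int16) == -129
    # json_value_to_int(32768, np.int16) would raise OverflowError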
| MSRStatus ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||||
| const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr, | |||||
| uint64_t *n_bytes) { | |||||
| uint64_t offset_address = 0; | |||||
| auto column_id = column_name_id_[column_name]; | |||||
| if (GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address) == FAILED) { | |||||
| return FAILED; | |||||
| } | |||||
| auto column_data_type = column_data_type_[column_id]; | |||||
| if (has_compress_blob_ && column_data_type == ColumnInt32) { | |||||
| if (UncompressInt<int32_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { | |||||
| return FAILED; | |||||
| } | |||||
| } else if (has_compress_blob_ && column_data_type == ColumnInt64) { | |||||
| if (UncompressInt<int64_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { | |||||
| return FAILED; | |||||
| } | |||||
| } else { | |||||
| *data = reinterpret_cast<const unsigned char *>(&(columns_blob[offset_address])); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) { | |||||
| auto it_column = column_name_id_.find(column_name); | |||||
| if (it_column == column_name_id_.end()) { | |||||
| return ColumnNotFound; | |||||
| } | |||||
| auto it_blob = blob_column_id_.find(column_name); | |||||
| return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob; | |||||
| } | |||||
| std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob) { | |||||
| // Skip if no compress columns | |||||
| if (!CheckCompressBlob()) return blob; | |||||
| std::vector<uint8_t> dst_blob; | |||||
| uint64_t i_src = 0; | |||||
| for (int64_t i = 0; i < num_blob_column_; i++) { | |||||
| // Get column data type | |||||
| auto src_data_type = column_data_type_[column_name_id_[blob_column_[i]]]; | |||||
| auto int_type = src_data_type == ColumnInt32 ? kInt32Type : kInt64Type; | |||||
| // Compress and return directly if the blob has only one column | |||||
| if (num_blob_column_ == 1) { | |||||
| return CompressInt(blob, int_type); | |||||
| } | |||||
| // Just copy and continue if the column data type is not int32/int64 | |||||
| uint64_t num_bytes = BytesBigToUInt64(blob, i_src, kInt64Type); | |||||
| if (src_data_type != ColumnInt32 && src_data_type != ColumnInt64) { | |||||
| dst_blob.insert(dst_blob.end(), blob.begin() + i_src, blob.begin() + i_src + kInt64Len + num_bytes); | |||||
| i_src += kInt64Len + num_bytes; | |||||
| continue; | |||||
| } | |||||
| // Get column slice in source blob | |||||
| std::vector<uint8_t> blob_slice(blob.begin() + i_src + kInt64Len, blob.begin() + i_src + kInt64Len + num_bytes); | |||||
| // Compress column | |||||
| auto dst_blob_slice = CompressInt(blob_slice, int_type); | |||||
| // Get new column size | |||||
| auto new_blob_size = UIntToBytesBig(dst_blob_slice.size(), kInt64Type); | |||||
| // Append new column size | |||||
| dst_blob.insert(dst_blob.end(), new_blob_size.begin(), new_blob_size.end()); | |||||
| // Append new column data | |||||
| dst_blob.insert(dst_blob.end(), dst_blob_slice.begin(), dst_blob_slice.end()); | |||||
| i_src += kInt64Len + num_bytes; | |||||
| } | |||||
| MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << "."; | |||||
| return dst_blob; | |||||
| } | |||||
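In a multi-column blob each column is framed as an 8-byte big-endian length followed by its payload (the kInt64Len prefix read by BytesBigToUInt64 above); CompressBlob re-encodes int32/int64 payloads and rewrites their length fields, and copies everything else through. A small Python sketch of that framing, under the same 8-byte-prefix assumption; a single-column blob carries no prefix at all, which is why the one-column case returns CompressInt(blob, ...) directly.

    def split_blob_columns(blob: bytes):
        """Split a concatenated multi-column blob into per-column payloads."""
        columns = []
        pos = 0
        while pos < len(blob):
            n_bytes = int.from_bytes(blob[pos:pos + 8], 'big')   # 8-byte big-endian length
            pos += 8
            columns.append(blob[pos:pos + n_bytes])
            pos += n_bytes
        return columns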
| vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type) { | |||||
| uint64_t i_size = kUnsignedOne << int_type; | |||||
| // Get number of elements | |||||
| uint64_t src_n_int = src_bytes.size() / i_size; | |||||
| // Calculate bitmap size (bytes) | |||||
| uint64_t bitmap_size = (src_n_int + kNumDataOfByte - 1) / kNumDataOfByte; | |||||
| // Initialize destination blob with more space than needed; it will be resized | |||||
| vector<uint8_t> dst_bytes(kBytesOfColumnLen + bitmap_size + src_bytes.size(), 0); | |||||
| // Write number of elements to destination blob | |||||
| vector<uint8_t> size_by_bytes = UIntToBytesBig(src_n_int, kInt32Type); | |||||
| for (uint64_t n = 0; n < kBytesOfColumnLen; n++) { | |||||
| dst_bytes[n] = size_by_bytes[n]; | |||||
| } | |||||
| // Write compressed int | |||||
| uint64_t i_dst = kBytesOfColumnLen + bitmap_size; | |||||
| for (uint64_t i = 0; i < src_n_int; i++) { | |||||
| // Initialize destination data type | |||||
| IntegerType dst_int_type = kInt8Type; | |||||
| // Shift to next int position | |||||
| uint64_t pos = i * (kUnsignedOne << int_type); | |||||
| // Narrow down this int | |||||
| int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, &dst_int_type); | |||||
| // Write this int to destination blob | |||||
| uint64_t u_n = *reinterpret_cast<uint64_t *>(&i_n); | |||||
| auto temp_bytes = UIntToBytesLittle(u_n, dst_int_type); | |||||
| for (uint64_t j = 0; j < (kUnsignedOne << dst_int_type); j++) { | |||||
| dst_bytes[i_dst++] = temp_bytes[j]; | |||||
| } | |||||
| // Update data type in bitmap | |||||
| dst_bytes[i / kNumDataOfByte + kBytesOfColumnLen] |= | |||||
| (dst_int_type << (kDataTypeBits * (kNumDataOfByte - kUnsignedOne - (i % kNumDataOfByte)))); | |||||
| } | |||||
| // Resize destination blob | |||||
| dst_bytes.resize(i_dst); | |||||
| MS_LOG(DEBUG) << "Compress blob field from " << src_bytes.size() << " to " << dst_bytes.size() << "."; | |||||
| return dst_bytes; | |||||
| } | |||||
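The compressed layout written above is: a fixed-size big-endian element count, a bitmap holding one small type code per element, then each element stored little-endian in the smallest signed width that holds it. A hedged Python sketch, assuming kBytesOfColumnLen = 4, kNumDataOfByte = 4, kDataTypeBits = 2 and type codes 0/1/2/3 for 1/2/4/8-byte ints (these constants are not shown in this diff):

    def compress_int(values):
        """Sketch of the compressed-int layout: count, 2-bit-per-element bitmap, packed ints."""
        n = len(values)
        bitmap = bytearray((n + 3) // 4)
        packed = bytearray()
        for i, v in enumerate(values):
            v = int(v)
            for code, width in enumerate((1, 2, 4, 8)):
                lo, hi = -(1 << (8 * width - 1)), (1 << (8 * width - 1)) - 1
                if lo <= v <= hi:
                    break
            bitmap[i // 4] |= code << (2 * (3 - i % 4))   # first element of each group in the high bits
            packed += v.to_bytes(width, 'little', signed=True)
        return n.to_bytes(4, 'big') + bytes(bitmap) + bytes(packed)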
| MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob, | |||||
| uint64_t *num_bytes, uint64_t *shift_idx) { | |||||
| if (num_blob_column_ == 1) { | |||||
| *num_bytes = columns_blob.size(); | |||||
| *shift_idx = 0; | |||||
| return SUCCESS; | |||||
| } | |||||
| auto blob_id = blob_column_id_[column_name_[column_id]]; | |||||
| for (int32_t i = 0; i < blob_id; i++) { | |||||
| *shift_idx += kInt64Len + BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type); | |||||
| } | |||||
| *num_bytes = BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type); | |||||
| (*shift_idx) += kInt64Len; | |||||
| return SUCCESS; | |||||
| } | |||||
| template <typename T> | |||||
| MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *data_ptr, | |||||
| const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, | |||||
| uint64_t shift_idx) { | |||||
| auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type); | |||||
| *num_bytes = sizeof(T) * num_elements; | |||||
| // Parse integer array | |||||
| uint64_t i_source = shift_idx + kBytesOfColumnLen + (num_elements + kNumDataOfByte - 1) / kNumDataOfByte; | |||||
| auto array_data = std::make_unique<T[]>(num_elements); | |||||
| for (uint64_t i = 0; i < num_elements; i++) { | |||||
| uint8_t iBitMap = columns_blob[shift_idx + kBytesOfColumnLen + i / kNumDataOfByte]; | |||||
| uint64_t i_type = (iBitMap >> ((kNumDataOfByte - 1 - (i % kNumDataOfByte)) * kDataTypeBits)) & kDataTypeBitMask; | |||||
| auto mr_int_type = static_cast<IntegerType>(i_type); | |||||
| int64_t i64 = BytesLittleToMinIntType(columns_blob, i_source, mr_int_type); | |||||
| i_source += (kUnsignedOne << i_type); | |||||
| array_data[i] = static_cast<T>(i64); | |||||
| } | |||||
| auto data = reinterpret_cast<const unsigned char *>(array_data.get()); | |||||
| *data_ptr = std::make_unique<unsigned char[]>(*num_bytes); | |||||
| memcpy(data_ptr->get(), data, *num_bytes); | |||||
| return SUCCESS; | |||||
| } | |||||
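UncompressInt inverts that layout on the reader side. A matching Python sketch (same assumed constants), with a round-trip check against compress_int from the previous block:

    def uncompress_int(buf: bytes):
        """Decode the sketch format back into a list of Python ints."""
        n = int.from_bytes(buf[:4], 'big')
        bitmap = buf[4:4 + (n + 3) // 4]
        pos = 4 + len(bitmap)
        values = []
        for i in range(n):
            code = (bitmap[i // 4] >> (2 * (3 - i % 4))) & 0b11
            width = 1 << code
            values.append(int.from_bytes(buf[pos:pos + width], 'little', signed=True))
            pos += width
        return values

    sample = [0, 1, -1, 127, -128, 128, -32769, 2147483647, -2147483649]
    assert uncompress_int(compress_int(sample)) == sample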
| uint64_t ShardColumn::BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos, | |||||
| const IntegerType &i_type) { | |||||
| uint64_t result = 0; | |||||
| for (uint64_t i = 0; i < (kUnsignedOne << i_type); i++) { | |||||
| result = (result << kBitsOfByte) + bytes_array[pos + i]; | |||||
| } | |||||
| return result; | |||||
| } | |||||
| std::vector<uint8_t> ShardColumn::UIntToBytesBig(uint64_t value, const IntegerType &i_type) { | |||||
| uint64_t n_bytes = kUnsignedOne << i_type; | |||||
| std::vector<uint8_t> result(n_bytes, 0); | |||||
| for (uint64_t i = 0; i < n_bytes; i++) { | |||||
| result[n_bytes - 1 - i] = value & std::numeric_limits<uint8_t>::max(); | |||||
| value >>= kBitsOfByte; | |||||
| } | |||||
| return result; | |||||
| } | |||||
| std::vector<uint8_t> ShardColumn::UIntToBytesLittle(uint64_t value, const IntegerType &i_type) { | |||||
| uint64_t n_bytes = kUnsignedOne << i_type; | |||||
| std::vector<uint8_t> result(n_bytes, 0); | |||||
| for (uint64_t i = 0; i < n_bytes; i++) { | |||||
| result[i] = value & std::numeric_limits<uint8_t>::max(); | |||||
| value >>= kBitsOfByte; | |||||
| } | |||||
| return result; | |||||
| } | |||||
| int64_t ShardColumn::BytesLittleToMinIntType(const std::vector<uint8_t> &bytes_array, const uint64_t &pos, | |||||
| const IntegerType &src_i_type, IntegerType *dst_i_type) { | |||||
| uint64_t u_temp = 0; | |||||
| for (uint64_t i = 0; i < (kUnsignedOne << src_i_type); i++) { | |||||
| u_temp = (u_temp << kBitsOfByte) + bytes_array[pos + (kUnsignedOne << src_i_type) - kUnsignedOne - i]; | |||||
| } | |||||
| int64_t i_out; | |||||
| switch (src_i_type) { | |||||
| case kInt8Type: { | |||||
| i_out = (int8_t)(u_temp & std::numeric_limits<uint8_t>::max()); | |||||
| break; | |||||
| } | |||||
| case kInt16Type: { | |||||
| i_out = (int16_t)(u_temp & std::numeric_limits<uint16_t>::max()); | |||||
| break; | |||||
| } | |||||
| case kInt32Type: { | |||||
| i_out = (int32_t)(u_temp & std::numeric_limits<uint32_t>::max()); | |||||
| break; | |||||
| } | |||||
| case kInt64Type: { | |||||
| i_out = (int64_t)(u_temp & std::numeric_limits<uint64_t>::max()); | |||||
| break; | |||||
| } | |||||
| default: { | |||||
| i_out = 0; | |||||
| } | |||||
| } | |||||
| if (!dst_i_type) { | |||||
| return i_out; | |||||
| } | |||||
| if (i_out >= static_cast<int64_t>(std::numeric_limits<int8_t>::min()) && | |||||
| i_out <= static_cast<int64_t>(std::numeric_limits<int8_t>::max())) { | |||||
| *dst_i_type = kInt8Type; | |||||
| } else if (i_out >= static_cast<int64_t>(std::numeric_limits<int16_t>::min()) && | |||||
| i_out <= static_cast<int64_t>(std::numeric_limits<int16_t>::max())) { | |||||
| *dst_i_type = kInt16Type; | |||||
| } else if (i_out >= static_cast<int64_t>(std::numeric_limits<int32_t>::min()) && | |||||
| i_out <= static_cast<int64_t>(std::numeric_limits<int32_t>::max())) { | |||||
| *dst_i_type = kInt32Type; | |||||
| } else { | |||||
| *dst_i_type = kInt64Type; | |||||
| } | |||||
| return i_out; | |||||
| } | |||||
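The endian helpers map directly onto Python's int.to_bytes / int.from_bytes, and the width returned through dst_i_type is exactly the smallest signed width that still holds the value, which is what feeds the 2-bit bitmap. A couple of concrete values, for orientation only:

    # length fields are big-endian, packed values are little-endian
    assert (256).to_bytes(8, 'big') == b'\x00' * 6 + b'\x01\x00'
    assert (300).to_bytes(2, 'little', signed=True) == b'\x2c\x01'
    assert int.from_bytes(b'\x2c\x01', 'little', signed=True) == 300
    # minimal signed width: 127 -> 1 byte, 128 -> 2 bytes, -32769 -> 4 bytes, 2**31 -> 8 bytes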
| } // namespace mindrecord | |||||
| } // namespace mindspore | |||||
| @@ -201,9 +201,9 @@ void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &heade | |||||
| json header; | json header; | ||||
| header = ret.second; | header = ret.second; | ||||
| header["shard_addresses"] = realAddresses; | header["shard_addresses"] = realAddresses; | ||||
| if (header["version"] != version_) { | |||||
| if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), header["version"]) == kSupportedVersion.end()) { | |||||
| MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump() | MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump() | ||||
| << ", lib version is: " << version_; | |||||
| << ", lib version is: " << kVersion; | |||||
| thread_status = true; | thread_status = true; | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -339,7 +339,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() { | |||||
| s += "\"shard_addresses\":" + address + ","; | s += "\"shard_addresses\":" + address + ","; | ||||
| s += "\"shard_id\":" + std::to_string(shardId) + ","; | s += "\"shard_id\":" + std::to_string(shardId) + ","; | ||||
| s += "\"statistics\":" + stats + ","; | s += "\"statistics\":" + stats + ","; | ||||
| s += "\"version\":\"" + version_ + "\""; | |||||
| s += "\"version\":\"" + std::string(kVersion) + "\""; | |||||
| s += "}"; | s += "}"; | ||||
| header.emplace_back(s); | header.emplace_back(s); | ||||
| } | } | ||||
| @@ -97,16 +97,13 @@ def populate_data(raw, blob, columns, blob_fields, schema): | |||||
| if not blob_fields: | if not blob_fields: | ||||
| return raw | return raw | ||||
| # Get the order preserving sequence of columns in blob | |||||
| ordered_columns = [] | |||||
| loaded_columns = [] | |||||
| if columns: | if columns: | ||||
| for blob_field in blob_fields: | |||||
| if blob_field in columns: | |||||
| ordered_columns.append(blob_field) | |||||
| for column in columns: | |||||
| if column in blob_fields: | |||||
| loaded_columns.append(column) | |||||
| else: | else: | ||||
| ordered_columns = blob_fields | |||||
| blob_bytes = bytes(blob) | |||||
| loaded_columns = blob_fields | |||||
| def _render_raw(field, blob_data): | def _render_raw(field, blob_data): | ||||
| data_type = schema[field]['type'] | data_type = schema[field]['type'] | ||||
| @@ -119,24 +116,6 @@ def populate_data(raw, blob, columns, blob_fields, schema): | |||||
| else: | else: | ||||
| raw[field] = blob_data | raw[field] = blob_data | ||||
| if len(blob_fields) == 1: | |||||
| if len(ordered_columns) == 1: | |||||
| _render_raw(blob_fields[0], blob_bytes) | |||||
| return raw | |||||
| return raw | |||||
| def _int_from_bytes(xbytes: bytes) -> int: | |||||
| return int.from_bytes(xbytes, 'big') | |||||
| def _blob_at_position(pos): | |||||
| start = 0 | |||||
| for _ in range(pos): | |||||
| n_bytes = _int_from_bytes(blob_bytes[start : start + 8]) | |||||
| start += 8 + n_bytes | |||||
| n_bytes = _int_from_bytes(blob_bytes[start : start + 8]) | |||||
| start += 8 | |||||
| return blob_bytes[start : start + n_bytes] | |||||
| for i, blob_field in enumerate(ordered_columns): | |||||
| _render_raw(blob_field, _blob_at_position(i)) | |||||
| for i, blob_field in enumerate(loaded_columns): | |||||
| _render_raw(blob_field, bytes(blob[i])) | |||||
| return raw | return raw | ||||
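With this change the blob parsing (length prefixes, per-column offsets) stays in the C++ reader, and populate_data receives one buffer per requested blob field, indexed as blob[i] in the caller's column order. A minimal sketch of the new selection logic shown above, with made-up field names:

    blob_fields = ["image", "array_a", "array_b"]
    columns = ["array_b", "label", "image"]
    loaded_columns = [c for c in columns if c in blob_fields] if columns else blob_fields
    assert loaded_columns == ["array_b", "image"]   # caller order kept, non-blob fields dropped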
| @@ -35,6 +35,7 @@ CV1_FILE_NAME = "../data/mindrecord/imagenet1.mindrecord" | |||||
| CV2_FILE_NAME = "../data/mindrecord/imagenet2.mindrecord" | CV2_FILE_NAME = "../data/mindrecord/imagenet2.mindrecord" | ||||
| CV_DIR_NAME = "../data/mindrecord/testImageNetData" | CV_DIR_NAME = "../data/mindrecord/testImageNetData" | ||||
| NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord" | NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord" | ||||
| OLD_NLP_FILE_NAME = "../data/mindrecord/testOldVersion/aclImdb.mindrecord" | |||||
| NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos" | NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos" | ||||
| NLP_FILE_VOCAB = "../data/mindrecord/testAclImdbData/vocab.txt" | NLP_FILE_VOCAB = "../data/mindrecord/testAclImdbData/vocab.txt" | ||||
| @@ -46,7 +47,8 @@ def add_and_remove_cv_file(): | |||||
| for x in range(FILES_NUM)] | for x in range(FILES_NUM)] | ||||
| for x in paths: | for x in paths: | ||||
| os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None | os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None | ||||
| os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None | |||||
| os.remove("{}.db".format(x)) if os.path.exists( | |||||
| "{}.db".format(x)) else None | |||||
| writer = FileWriter(CV_FILE_NAME, FILES_NUM) | writer = FileWriter(CV_FILE_NAME, FILES_NUM) | ||||
| data = get_data(CV_DIR_NAME) | data = get_data(CV_DIR_NAME) | ||||
| cv_schema_json = {"id": {"type": "int32"}, | cv_schema_json = {"id": {"type": "int32"}, | ||||
| @@ -96,13 +98,105 @@ def add_and_remove_nlp_file(): | |||||
| os.remove("{}.db".format(x)) | os.remove("{}.db".format(x)) | ||||
| @pytest.fixture | |||||
| def add_and_remove_nlp_compress_file(): | |||||
| """add/remove nlp file""" | |||||
| paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0')) | |||||
| for x in range(FILES_NUM)] | |||||
| for x in paths: | |||||
| if os.path.exists("{}".format(x)): | |||||
| os.remove("{}".format(x)) | |||||
| if os.path.exists("{}.db".format(x)): | |||||
| os.remove("{}.db".format(x)) | |||||
| writer = FileWriter(NLP_FILE_NAME, FILES_NUM) | |||||
| data = [] | |||||
| for row_id in range(16): | |||||
| data.append({ | |||||
| "label": row_id, | |||||
| "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, | |||||
| 255, 256, -32768, 32767, -32769, 32768, -2147483648, | |||||
| 2147483647], dtype=np.int32), [-1]), | |||||
| "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255, | |||||
| 256, -32768, 32767, -32769, 32768, -2147483648, 2147483647, -2147483649, 2147483649, -922337036854775808, 9223372036854775807]), [1, -1]), | |||||
| "array_c": str.encode("nlp data"), | |||||
| "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1]) | |||||
| }) | |||||
| nlp_schema_json = {"label": {"type": "int32"}, | |||||
| "array_a": {"type": "int32", | |||||
| "shape": [-1]}, | |||||
| "array_b": {"type": "int64", | |||||
| "shape": [1, -1]}, | |||||
| "array_c": {"type": "bytes"}, | |||||
| "array_d": {"type": "int64", | |||||
| "shape": [2, -1]} | |||||
| } | |||||
| writer.set_header_size(1 << 14) | |||||
| writer.set_page_size(1 << 15) | |||||
| writer.add_schema(nlp_schema_json, "nlp_schema") | |||||
| writer.write_raw_data(data) | |||||
| writer.commit() | |||||
| yield "yield_nlp_data" | |||||
| for x in paths: | |||||
| os.remove("{}".format(x)) | |||||
| os.remove("{}.db".format(x)) | |||||
| def test_nlp_compress_data(add_and_remove_nlp_compress_file): | |||||
| """tutorial for nlp minderdataset.""" | |||||
| data = [] | |||||
| for row_id in range(16): | |||||
| data.append({ | |||||
| "label": row_id, | |||||
| "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, | |||||
| 255, 256, -32768, 32767, -32769, 32768, -2147483648, | |||||
| 2147483647], dtype=np.int32), [-1]), | |||||
| "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255, | |||||
| 256, -32768, 32767, -32769, 32768, -2147483648, 2147483647, -2147483649, 2147483649, -922337036854775808, 9223372036854775807]), [1, -1]), | |||||
| "array_c": str.encode("nlp data"), | |||||
| "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1]) | |||||
| }) | |||||
| num_readers = 1 | |||||
| data_set = ds.MindDataset( | |||||
| NLP_FILE_NAME + "0", None, num_readers, shuffle=False) | |||||
| assert data_set.get_dataset_size() == 16 | |||||
| num_iter = 0 | |||||
| for x, item in zip(data, data_set.create_dict_iterator()): | |||||
| assert (item["array_a"] == x["array_a"]).all() | |||||
| assert (item["array_b"] == x["array_b"]).all() | |||||
| assert item["array_c"].tobytes() == x["array_c"] | |||||
| assert (item["array_d"] == x["array_d"]).all() | |||||
| assert item["label"] == x["label"] | |||||
| num_iter += 1 | |||||
| assert num_iter == 16 | |||||
| def test_nlp_compress_data_old_version(add_and_remove_nlp_compress_file): | |||||
| """tutorial for nlp minderdataset.""" | |||||
| num_readers = 1 | |||||
| data_set = ds.MindDataset( | |||||
| NLP_FILE_NAME + "0", None, num_readers, shuffle=False) | |||||
| old_data_set = ds.MindDataset( | |||||
| OLD_NLP_FILE_NAME + "0", None, num_readers, shuffle=False) | |||||
| assert old_data_set.get_dataset_size() == 16 | |||||
| num_iter = 0 | |||||
| for x, item in zip(old_data_set.create_dict_iterator(), data_set.create_dict_iterator()): | |||||
| assert (item["array_a"] == x["array_a"]).all() | |||||
| assert (item["array_b"] == x["array_b"]).all() | |||||
| assert (item["array_c"] == x["array_c"]).all() | |||||
| assert (item["array_d"] == x["array_d"]).all() | |||||
| assert item["label"] == x["label"] | |||||
| num_iter += 1 | |||||
| assert num_iter == 16 | |||||
| def test_cv_minddataset_writer_tutorial(): | def test_cv_minddataset_writer_tutorial(): | ||||
| """tutorial for cv dataset writer.""" | """tutorial for cv dataset writer.""" | ||||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | ||||
| for x in range(FILES_NUM)] | for x in range(FILES_NUM)] | ||||
| for x in paths: | for x in paths: | ||||
| os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None | os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None | ||||
| os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None | |||||
| os.remove("{}.db".format(x)) if os.path.exists( | |||||
| "{}.db".format(x)) else None | |||||
| writer = FileWriter(CV_FILE_NAME, FILES_NUM) | writer = FileWriter(CV_FILE_NAME, FILES_NUM) | ||||
| data = get_data(CV_DIR_NAME) | data = get_data(CV_DIR_NAME) | ||||
| cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, | cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, | ||||
| @@ -127,8 +221,10 @@ def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file): | |||||
| num_shards=num_shards, shard_id=partition_id) | num_shards=num_shards, shard_id=partition_id) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info("-------------- partition : {} ------------------------".format(partition_id)) | |||||
| logger.info("-------------- item[label]: {} -----------------------".format(item["label"])) | |||||
| logger.info( | |||||
| "-------------- partition : {} ------------------------".format(partition_id)) | |||||
| logger.info( | |||||
| "-------------- item[label]: {} -----------------------".format(item["label"])) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| return num_iter | return num_iter | ||||
| @@ -147,9 +243,12 @@ def test_cv_minddataset_dataset_size(add_and_remove_cv_file): | |||||
| data_set = data_set.repeat(repeat_num) | data_set = data_set.repeat(repeat_num) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info("-------------- get dataset size {} -----------------".format(num_iter)) | |||||
| logger.info("-------------- item[label]: {} ---------------------".format(item["label"])) | |||||
| logger.info("-------------- item[data]: {} ----------------------".format(item["data"])) | |||||
| logger.info( | |||||
| "-------------- get dataset size {} -----------------".format(num_iter)) | |||||
| logger.info( | |||||
| "-------------- item[label]: {} ---------------------".format(item["label"])) | |||||
| logger.info( | |||||
| "-------------- item[data]: {} ----------------------".format(item["data"])) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| assert num_iter == 20 | assert num_iter == 20 | ||||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | ||||
| @@ -163,17 +262,22 @@ def test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file): | |||||
| num_readers = 4 | num_readers = 4 | ||||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers) | data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers) | ||||
| decode_op = vision.Decode() | decode_op = vision.Decode() | ||||
| data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2) | |||||
| data_set = data_set.map( | |||||
| input_columns=["data"], operations=decode_op, num_parallel_workers=2) | |||||
| resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR) | resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR) | ||||
| data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2) | |||||
| data_set = data_set.map(input_columns="data", | |||||
| operations=resize_op, num_parallel_workers=2) | |||||
| data_set = data_set.batch(2) | data_set = data_set.batch(2) | ||||
| data_set = data_set.repeat(2) | data_set = data_set.repeat(2) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| labels = [] | labels = [] | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info("-------------- get dataset size {} -----------------".format(num_iter)) | |||||
| logger.info("-------------- item[label]: {} ---------------------".format(item["label"])) | |||||
| logger.info("-------------- item[data]: {} ----------------------".format(item["data"])) | |||||
| logger.info( | |||||
| "-------------- get dataset size {} -----------------".format(num_iter)) | |||||
| logger.info( | |||||
| "-------------- item[label]: {} ---------------------".format(item["label"])) | |||||
| logger.info( | |||||
| "-------------- item[data]: {} ----------------------".format(item["data"])) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| labels.append(item["label"]) | labels.append(item["label"]) | ||||
| assert num_iter == 10 | assert num_iter == 10 | ||||
| @@ -189,15 +293,20 @@ def test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file): | |||||
| num_readers = 4 | num_readers = 4 | ||||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers) | data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers) | ||||
| decode_op = vision.Decode() | decode_op = vision.Decode() | ||||
| data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2) | |||||
| data_set = data_set.map( | |||||
| input_columns=["data"], operations=decode_op, num_parallel_workers=2) | |||||
| resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR) | resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR) | ||||
| data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2) | |||||
| data_set = data_set.map(input_columns="data", | |||||
| operations=resize_op, num_parallel_workers=2) | |||||
| data_set = data_set.batch(32, drop_remainder=True) | data_set = data_set.batch(32, drop_remainder=True) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info("-------------- get dataset size {} -----------------".format(num_iter)) | |||||
| logger.info("-------------- item[label]: {} ---------------------".format(item["label"])) | |||||
| logger.info("-------------- item[data]: {} ----------------------".format(item["data"])) | |||||
| logger.info( | |||||
| "-------------- get dataset size {} -----------------".format(num_iter)) | |||||
| logger.info( | |||||
| "-------------- item[label]: {} ---------------------".format(item["label"])) | |||||
| logger.info( | |||||
| "-------------- item[data]: {} ----------------------".format(item["data"])) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| assert num_iter == 0 | assert num_iter == 0 | ||||
| @@ -206,7 +315,8 @@ def test_cv_minddataset_issue_888(add_and_remove_cv_file): | |||||
| """issue 888 test.""" | """issue 888 test.""" | ||||
| columns_list = ["data", "label"] | columns_list = ["data", "label"] | ||||
| num_readers = 2 | num_readers = 2 | ||||
| data = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, shuffle=False, num_shards=5, shard_id=1) | |||||
| data = ds.MindDataset(CV_FILE_NAME + "0", columns_list, | |||||
| num_readers, shuffle=False, num_shards=5, shard_id=1) | |||||
| data = data.shuffle(2) | data = data.shuffle(2) | ||||
| data = data.repeat(9) | data = data.repeat(9) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| @@ -226,9 +336,12 @@ def test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file): | |||||
| data_set = data_set.repeat(repeat_num) | data_set = data_set.repeat(repeat_num) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info("-------------- block reader repeat tow {} -----------------".format(num_iter)) | |||||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info( | |||||
| "-------------- block reader repeat tow {} -----------------".format(num_iter)) | |||||
| logger.info( | |||||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| logger.info( | |||||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| assert num_iter == 20 | assert num_iter == 20 | ||||
| @@ -244,10 +357,14 @@ def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_rem | |||||
| data_set = data_set.repeat(repeat_num) | data_set = data_set.repeat(repeat_num) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info("-------------- block reader repeat tow {} -----------------".format(num_iter)) | |||||
| logger.info("-------------- item[id]: {} ----------------------------".format(item["id"])) | |||||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info( | |||||
| "-------------- block reader repeat tow {} -----------------".format(num_iter)) | |||||
| logger.info( | |||||
| "-------------- item[id]: {} ----------------------------".format(item["id"])) | |||||
| logger.info( | |||||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| logger.info( | |||||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| assert num_iter == 20 | assert num_iter == 20 | ||||
| @@ -256,15 +373,21 @@ def test_cv_minddataset_reader_file_list(add_and_remove_cv_file): | |||||
| """tutorial for cv minderdataset.""" | """tutorial for cv minderdataset.""" | ||||
| columns_list = ["data", "file_name", "label"] | columns_list = ["data", "file_name", "label"] | ||||
| num_readers = 4 | num_readers = 4 | ||||
| data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)], columns_list, num_readers) | |||||
| data_set = ds.MindDataset([CV_FILE_NAME + str(x) | |||||
| for x in range(FILES_NUM)], columns_list, num_readers) | |||||
| assert data_set.get_dataset_size() == 10 | assert data_set.get_dataset_size() == 10 | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| logger.info( | |||||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info( | |||||
| "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||||
| logger.info( | |||||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info( | |||||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||||
| logger.info( | |||||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| assert num_iter == 10 | assert num_iter == 10 | ||||
| @@ -277,11 +400,16 @@ def test_cv_minddataset_reader_one_partition(add_and_remove_cv_file): | |||||
| assert data_set.get_dataset_size() < 10 | assert data_set.get_dataset_size() < 10 | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| logger.info( | |||||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info( | |||||
| "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||||
| logger.info( | |||||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info( | |||||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||||
| logger.info( | |||||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| assert num_iter < 10 | assert num_iter < 10 | ||||
| @@ -324,11 +452,16 @@ def test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file): | |||||
| assert data_set.get_dataset_size() == 30 | assert data_set.get_dataset_size() == 30 | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| logger.info( | |||||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info( | |||||
| "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||||
| logger.info( | |||||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info( | |||||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||||
| logger.info( | |||||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| assert num_iter == 30 | assert num_iter == 30 | ||||
| if os.path.exists(CV1_FILE_NAME): | if os.path.exists(CV1_FILE_NAME): | ||||
| @@ -346,7 +479,8 @@ def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file): | |||||
| for x in range(FILES_NUM)] | for x in range(FILES_NUM)] | ||||
| for x in paths: | for x in paths: | ||||
| os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None | os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None | ||||
| os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None | |||||
| os.remove("{}.db".format(x)) if os.path.exists( | |||||
| "{}.db".format(x)) else None | |||||
| writer = FileWriter(CV1_FILE_NAME, FILES_NUM) | writer = FileWriter(CV1_FILE_NAME, FILES_NUM) | ||||
| data = get_data(CV_DIR_NAME) | data = get_data(CV_DIR_NAME) | ||||
| cv_schema_json = {"id": {"type": "int32"}, | cv_schema_json = {"id": {"type": "int32"}, | ||||
| @@ -365,11 +499,16 @@ def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file): | |||||
| assert data_set.get_dataset_size() < 20 | assert data_set.get_dataset_size() < 20 | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| logger.info( | |||||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info( | |||||
| "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||||
| logger.info( | |||||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info( | |||||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||||
| logger.info( | |||||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| assert num_iter < 20 | assert num_iter < 20 | ||||
| for x in paths: | for x in paths: | ||||
| @@ -385,11 +524,16 @@ def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file): | |||||
| assert data_set.get_dataset_size() == 10 | assert data_set.get_dataset_size() == 10 | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| logger.info( | |||||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info( | |||||
| "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||||
| logger.info( | |||||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info( | |||||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||||
| logger.info( | |||||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| assert num_iter == 10 | assert num_iter == 10 | ||||
| @@ -401,10 +545,14 @@ def test_nlp_minddataset_reader_basic_tutorial(add_and_remove_nlp_file): | |||||
| assert data_set.get_dataset_size() == 10 | assert data_set.get_dataset_size() == 10 | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info("-------------- num_iter: {} ------------------------".format(num_iter)) | |||||
| logger.info("-------------- item[id]: {} ------------------------".format(item["id"])) | |||||
| logger.info("-------------- item[rating]: {} --------------------".format(item["rating"])) | |||||
| logger.info( | |||||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info( | |||||
| "-------------- num_iter: {} ------------------------".format(num_iter)) | |||||
| logger.info( | |||||
| "-------------- item[id]: {} ------------------------".format(item["id"])) | |||||
| logger.info( | |||||
| "-------------- item[rating]: {} --------------------".format(item["rating"])) | |||||
| logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format( | logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format( | ||||
| item["input_ids"], item["input_ids"].shape)) | item["input_ids"], item["input_ids"].shape)) | ||||
| logger.info("-------------- item[input_mask]: {}, shape: {} -----------------".format( | logger.info("-------------- item[input_mask]: {}, shape: {} -----------------".format( | ||||
| @@ -445,10 +593,13 @@ def test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_ | |||||
| # define map operations | # define map operations | ||||
| decode_op = vision.Decode() | decode_op = vision.Decode() | ||||
| resize_op = vision.Resize((resize_height, resize_width), ds.transforms.vision.Inter.LINEAR) | |||||
| resize_op = vision.Resize( | |||||
| (resize_height, resize_width), ds.transforms.vision.Inter.LINEAR) | |||||
| data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=4) | |||||
| data_set = data_set.map(input_columns=["data"], operations=resize_op, num_parallel_workers=4) | |||||
| data_set = data_set.map( | |||||
| input_columns=["data"], operations=decode_op, num_parallel_workers=4) | |||||
| data_set = data_set.map( | |||||
| input_columns=["data"], operations=resize_op, num_parallel_workers=4) | |||||
| data_set = data_set.batch(2) | data_set = data_set.batch(2) | ||||
| assert data_set.get_dataset_size() == 5 | assert data_set.get_dataset_size() == 5 | ||||
| @@ -468,11 +619,16 @@ def test_cv_minddataset_reader_no_columns(add_and_remove_cv_file): | |||||
| assert data_set.get_dataset_size() == 10 | assert data_set.get_dataset_size() == 10 | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| logger.info( | |||||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info( | |||||
| "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||||
| logger.info( | |||||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info( | |||||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||||
| logger.info( | |||||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| assert num_iter == 10 | assert num_iter == 10 | ||||
| @@ -486,11 +642,16 @@ def test_cv_minddataset_reader_repeat_tutorial(add_and_remove_cv_file): | |||||
| data_set = data_set.repeat(repeat_num) | data_set = data_set.repeat(repeat_num) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info("-------------- repeat two test {} ------------------------".format(num_iter)) | |||||
| logger.info("-------------- len(item[data]): {} -----------------------".format(len(item["data"]))) | |||||
| logger.info("-------------- item[data]: {} ----------------------------".format(item["data"])) | |||||
| logger.info("-------------- item[file_name]: {} -----------------------".format(item["file_name"])) | |||||
| logger.info("-------------- item[label]: {} ---------------------------".format(item["label"])) | |||||
| logger.info( | |||||
| "-------------- repeat two test {} ------------------------".format(num_iter)) | |||||
| logger.info( | |||||
| "-------------- len(item[data]): {} -----------------------".format(len(item["data"]))) | |||||
| logger.info( | |||||
| "-------------- item[data]: {} ----------------------------".format(item["data"])) | |||||
| logger.info( | |||||
| "-------------- item[file_name]: {} -----------------------".format(item["file_name"])) | |||||
| logger.info( | |||||
| "-------------- item[label]: {} ---------------------------".format(item["label"])) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| assert num_iter == 20 | assert num_iter == 20 | ||||
| @@ -599,7 +760,8 @@ def get_mkv_data(dir_name): | |||||
| "id": index} | "id": index} | ||||
| data_list.append(data_json) | data_list.append(data_json) | ||||
| index += 1 | index += 1 | ||||
| logger.info('{} images are missing'.format(len(file_list) - len(data_list))) | |||||
| logger.info('{} images are missing'.format( | |||||
| len(file_list) - len(data_list))) | |||||
| return data_list | return data_list | ||||
| @@ -686,6 +848,10 @@ def inputs(vectors, maxlen=50): | |||||
| def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | ||||
| mindrecord_file_name = "test.mindrecord" | mindrecord_file_name = "test.mindrecord" | ||||
| if os.path.exists("{}".format(mindrecord_file_name)): | |||||
| os.remove("{}".format(mindrecord_file_name)) | |||||
| if os.path.exists("{}.db".format(mindrecord_file_name)): | |||||
| os.remove("{}.db".format(x)) | |||||
| data = [{"file_name": "001.jpg", "label": 4, | data = [{"file_name": "001.jpg", "label": 4, | ||||
| "image1": bytes("image1 bytes abc", encoding='UTF-8'), | "image1": bytes("image1 bytes abc", encoding='UTF-8'), | ||||
| "image2": bytes("image1 bytes def", encoding='UTF-8'), | "image2": bytes("image1 bytes def", encoding='UTF-8'), | ||||
| @@ -782,7 +948,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||||
| data_value_to_list = [] | data_value_to_list = [] | ||||
| for item in data: | for item in data: | ||||
| new_data = {} | new_data = {} | ||||
| new_data['file_name'] = np.asarray(list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8) | |||||
| new_data['file_name'] = np.asarray( | |||||
| list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8) | |||||
| new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) | new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) | ||||
| new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) | new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) | ||||
| new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) | new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) | ||||
| @@ -807,7 +974,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||||
| assert len(item) == 13 | assert len(item) == 13 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -815,7 +983,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||||
| num_readers = 2 | num_readers = 2 | ||||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | ||||
| columns_list=["source_sos_ids", "source_sos_mask", "target_sos_ids"], | |||||
| columns_list=["source_sos_ids", | |||||
| "source_sos_mask", "target_sos_ids"], | |||||
| num_parallel_workers=num_readers, | num_parallel_workers=num_readers, | ||||
| shuffle=False) | shuffle=False) | ||||
| assert data_set.get_dataset_size() == 6 | assert data_set.get_dataset_size() == 6 | ||||
| @@ -832,7 +1001,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||||
| num_readers = 1 | num_readers = 1 | ||||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | ||||
| columns_list=["image2", "source_sos_mask", "image3", "target_sos_ids"], | |||||
| columns_list=[ | |||||
| "image2", "source_sos_mask", "image3", "target_sos_ids"], | |||||
| num_parallel_workers=num_readers, | num_parallel_workers=num_readers, | ||||
| shuffle=False) | shuffle=False) | ||||
| assert data_set.get_dataset_size() == 6 | assert data_set.get_dataset_size() == 6 | ||||
| @@ -841,7 +1011,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||||
| assert len(item) == 4 | assert len(item) == 4 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -849,7 +1020,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||||
| num_readers = 3 | num_readers = 3 | ||||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | ||||
| columns_list=["target_sos_ids", "image4", "source_sos_ids"], | |||||
| columns_list=["target_sos_ids", | |||||
| "image4", "source_sos_ids"], | |||||
| num_parallel_workers=num_readers, | num_parallel_workers=num_readers, | ||||
| shuffle=False) | shuffle=False) | ||||
| assert data_set.get_dataset_size() == 6 | assert data_set.get_dataset_size() == 6 | ||||
| @@ -858,7 +1030,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||||
| assert len(item) == 3 | assert len(item) == 3 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -866,7 +1039,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||||
| num_readers = 3 | num_readers = 3 | ||||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | ||||
| columns_list=["target_sos_ids", "image5", "image4", "image3", "source_sos_ids"], | |||||
| columns_list=["target_sos_ids", "image5", | |||||
| "image4", "image3", "source_sos_ids"], | |||||
| num_parallel_workers=num_readers, | num_parallel_workers=num_readers, | ||||
| shuffle=False) | shuffle=False) | ||||
| assert data_set.get_dataset_size() == 6 | assert data_set.get_dataset_size() == 6 | ||||
| @@ -875,7 +1049,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||||
| assert len(item) == 5 | assert len(item) == 5 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -883,7 +1058,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||||
| num_readers = 1 | num_readers = 1 | ||||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | ||||
| columns_list=["target_eos_mask", "image5", "image2", "source_sos_mask", "label"], | |||||
| columns_list=["target_eos_mask", "image5", | |||||
| "image2", "source_sos_mask", "label"], | |||||
| num_parallel_workers=num_readers, | num_parallel_workers=num_readers, | ||||
| shuffle=False) | shuffle=False) | ||||
| assert data_set.get_dataset_size() == 6 | assert data_set.get_dataset_size() == 6 | ||||
| @@ -892,7 +1068,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||||
| assert len(item) == 5 | assert len(item) == 5 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -910,7 +1087,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||||
| assert len(item) == 11 | assert len(item) == 11 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -975,7 +1153,8 @@ def test_write_with_multi_bytes_and_MindDataset(): | |||||
| data_value_to_list = [] | data_value_to_list = [] | ||||
| for item in data: | for item in data: | ||||
| new_data = {} | new_data = {} | ||||
| new_data['file_name'] = np.asarray(list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8) | |||||
| new_data['file_name'] = np.asarray( | |||||
| list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8) | |||||
| new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) | new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) | ||||
| new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) | new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) | ||||
| new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) | new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) | ||||
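Note on the comparison pattern in these hunks: the tests expect MindDataset to hand every selected column back as a numpy array, with bytes/string columns flattened to uint8, so the expected values are pre-converted to matching arrays (str -> utf-8 bytes -> uint8, raw bytes -> uint8, scalar -> 1-element int32) before the elementwise `.all()` check. A minimal sketch of that conversion, using a hypothetical sample dict rather than the data in the test:

import numpy as np

# hypothetical raw sample; the field names only mirror the test data
sample = {"file_name": "001.jpg", "image1": b"\x89PNG\r\n", "label": 7}

expected = {
    # str -> utf-8 bytes -> uint8 array, the form the tests expect for string columns
    "file_name": np.asarray(list(bytes(sample["file_name"], encoding="utf-8")), dtype=np.uint8),
    # raw bytes -> uint8 array
    "image1": np.asarray(list(sample["image1"]), dtype=np.uint8),
    # scalar -> 1-element int32 array
    "label": np.asarray([sample["label"]], dtype=np.int32),
}

# the conversion is lossless, so an elementwise comparison is a faithful check
assert bytes(expected["file_name"]).decode("utf-8") == sample["file_name"]
assert (expected["image1"] == np.frombuffer(sample["image1"], dtype=np.uint8)).all()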
| @@ -994,7 +1173,8 @@ def test_write_with_multi_bytes_and_MindDataset(): | |||||
| assert len(item) == 7 | assert len(item) == 7 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -1011,7 +1191,8 @@ def test_write_with_multi_bytes_and_MindDataset(): | |||||
| assert len(item) == 3 | assert len(item) == 3 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -1028,7 +1209,8 @@ def test_write_with_multi_bytes_and_MindDataset(): | |||||
| assert len(item) == 2 | assert len(item) == 2 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -1045,7 +1227,8 @@ def test_write_with_multi_bytes_and_MindDataset(): | |||||
| assert len(item) == 2 | assert len(item) == 2 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -1062,7 +1245,8 @@ def test_write_with_multi_bytes_and_MindDataset(): | |||||
| assert len(item) == 3 | assert len(item) == 3 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -1070,7 +1254,8 @@ def test_write_with_multi_bytes_and_MindDataset(): | |||||
| num_readers = 2 | num_readers = 2 | ||||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | ||||
| columns_list=["image4", "image5", "image2", "image3", "file_name"], | |||||
| columns_list=["image4", "image5", | |||||
| "image2", "image3", "file_name"], | |||||
| num_parallel_workers=num_readers, | num_parallel_workers=num_readers, | ||||
| shuffle=False) | shuffle=False) | ||||
| assert data_set.get_dataset_size() == 6 | assert data_set.get_dataset_size() == 6 | ||||
| @@ -1079,7 +1264,8 @@ def test_write_with_multi_bytes_and_MindDataset(): | |||||
| assert len(item) == 5 | assert len(item) == 5 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -1177,7 +1363,8 @@ def test_write_with_multi_array_and_MindDataset(): | |||||
| assert len(item) == 8 | assert len(item) == 8 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -1196,7 +1383,8 @@ def test_write_with_multi_array_and_MindDataset(): | |||||
| assert len(item) == 6 | assert len(item) == 6 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -1215,7 +1403,8 @@ def test_write_with_multi_array_and_MindDataset(): | |||||
| assert len(item) == 3 | assert len(item) == 3 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -1234,7 +1423,8 @@ def test_write_with_multi_array_and_MindDataset(): | |||||
| assert len(item) == 3 | assert len(item) == 3 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -1251,7 +1441,8 @@ def test_write_with_multi_array_and_MindDataset(): | |||||
| assert len(item) == 1 | assert len(item) == 1 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -1271,7 +1462,8 @@ def test_write_with_multi_array_and_MindDataset(): | |||||
| assert len(item) == 8 | assert len(item) == 8 | ||||
| for field in item: | for field in item: | ||||
| if isinstance(item[field], np.ndarray): | if isinstance(item[field], np.ndarray): | ||||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||||
| assert (item[field] == | |||||
| data_value_to_list[num_iter][field]).all() | |||||
| else: | else: | ||||
| assert item[field] == data_value_to_list[num_iter][field] | assert item[field] == data_value_to_list[num_iter][field] | ||||
| num_iter += 1 | num_iter += 1 | ||||
| @@ -25,8 +25,24 @@ from mindspore.mindrecord import SUCCESS | |||||
| CIFAR100_DIR = "../data/mindrecord/testCifar100Data" | CIFAR100_DIR = "../data/mindrecord/testCifar100Data" | ||||
| MINDRECORD_FILE = "./cifar100.mindrecord" | MINDRECORD_FILE = "./cifar100.mindrecord" | ||||
| def test_cifar100_to_mindrecord_without_index_fields(): | |||||
| @pytest.fixture | |||||
| def fixture_file(): | |||||
| """add/remove file""" | |||||
| def remove_file(x): | |||||
| if os.path.exists("{}".format(x)): | |||||
| os.remove("{}".format(x)) | |||||
| if os.path.exists("{}.db".format(x)): | |||||
| os.remove("{}.db".format(x)) | |||||
| if os.path.exists("{}_test".format(x)): | |||||
| os.remove("{}_test".format(x)) | |||||
| if os.path.exists("{}_test.db".format(x)): | |||||
| os.remove("{}_test.db".format(x)) | |||||
| remove_file(MINDRECORD_FILE) | |||||
| yield "yield_fixture_data" | |||||
| remove_file(MINDRECORD_FILE) | |||||
| def test_cifar100_to_mindrecord_without_index_fields(fixture_file): | |||||
| """test transform cifar100 dataset to mindrecord without index fields.""" | """test transform cifar100 dataset to mindrecord without index fields.""" | ||||
| cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE) | cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE) | ||||
| ret = cifar100_transformer.transform() | ret = cifar100_transformer.transform() | ||||
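The per-test os.remove cleanup is replaced by the fixture_file fixture above: everything before the yield runs as setup and everything after it as teardown, so the output files are cleared even when a test body raises. A minimal sketch of the same yield-fixture pattern, with a hypothetical file name that is not taken from this change:

import os
import pytest

DEMO_FILE = "./demo.mindrecord"   # hypothetical output name

def _remove(path):
    # remove the data file and its companion .db index if present
    for p in (path, path + ".db"):
        if os.path.exists(p):
            os.remove(p)

@pytest.fixture
def demo_fixture():
    _remove(DEMO_FILE)        # setup: start from a clean state
    yield                     # the test body runs here
    _remove(DEMO_FILE)        # teardown: runs even if the test failed

def test_something(demo_fixture):
    with open(DEMO_FILE, "w") as f:
        f.write("just for test")
    assert os.path.exists(DEMO_FILE)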
| @@ -34,25 +50,14 @@ def test_cifar100_to_mindrecord_without_index_fields(): | |||||
| assert os.path.exists(MINDRECORD_FILE) | assert os.path.exists(MINDRECORD_FILE) | ||||
| assert os.path.exists(MINDRECORD_FILE + "_test") | assert os.path.exists(MINDRECORD_FILE + "_test") | ||||
| read() | read() | ||||
| os.remove("{}".format(MINDRECORD_FILE)) | |||||
| os.remove("{}.db".format(MINDRECORD_FILE)) | |||||
| os.remove("{}".format(MINDRECORD_FILE + "_test")) | |||||
| os.remove("{}.db".format(MINDRECORD_FILE + "_test")) | |||||
| def test_cifar100_to_mindrecord(): | |||||
| def test_cifar100_to_mindrecord(fixture_file): | |||||
| """test transform cifar100 dataset to mindrecord.""" | """test transform cifar100 dataset to mindrecord.""" | ||||
| cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE) | cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE) | ||||
| cifar100_transformer.transform(['fine_label', 'coarse_label']) | cifar100_transformer.transform(['fine_label', 'coarse_label']) | ||||
| assert os.path.exists(MINDRECORD_FILE) | assert os.path.exists(MINDRECORD_FILE) | ||||
| assert os.path.exists(MINDRECORD_FILE + "_test") | assert os.path.exists(MINDRECORD_FILE + "_test") | ||||
| read() | read() | ||||
| os.remove("{}".format(MINDRECORD_FILE)) | |||||
| os.remove("{}.db".format(MINDRECORD_FILE)) | |||||
| os.remove("{}".format(MINDRECORD_FILE + "_test")) | |||||
| os.remove("{}.db".format(MINDRECORD_FILE + "_test")) | |||||
| def read(): | def read(): | ||||
| @@ -77,8 +82,7 @@ def read(): | |||||
| assert count == 4 | assert count == 4 | ||||
| reader.close() | reader.close() | ||||
| def test_cifar100_to_mindrecord_illegal_file_name(): | |||||
| def test_cifar100_to_mindrecord_illegal_file_name(fixture_file): | |||||
| """ | """ | ||||
| test transform cifar100 dataset to mindrecord | test transform cifar100 dataset to mindrecord | ||||
| when file name contains illegal character. | when file name contains illegal character. | ||||
| @@ -88,8 +92,7 @@ def test_cifar100_to_mindrecord_illegal_file_name(): | |||||
| cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename) | cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename) | ||||
| cifar100_transformer.transform() | cifar100_transformer.transform() | ||||
| def test_cifar100_to_mindrecord_filename_start_with_space(): | |||||
| def test_cifar100_to_mindrecord_filename_start_with_space(fixture_file): | |||||
| """ | """ | ||||
| test transform cifar100 dataset to mindrecord | test transform cifar100 dataset to mindrecord | ||||

| when file name starts with space. | when file name starts with space. | ||||
| @@ -100,8 +103,7 @@ def test_cifar100_to_mindrecord_filename_start_with_space(): | |||||
| cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename) | cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename) | ||||
| cifar100_transformer.transform() | cifar100_transformer.transform() | ||||
| def test_cifar100_to_mindrecord_filename_contain_space(): | |||||
| def test_cifar100_to_mindrecord_filename_contain_space(fixture_file): | |||||
| """ | """ | ||||
| test transform cifar100 dataset to mindrecord | test transform cifar100 dataset to mindrecord | ||||
| when file name contains space. | when file name contains space. | ||||
| @@ -111,14 +113,8 @@ def test_cifar100_to_mindrecord_filename_contain_space(): | |||||
| cifar100_transformer.transform() | cifar100_transformer.transform() | ||||
| assert os.path.exists(filename) | assert os.path.exists(filename) | ||||
| assert os.path.exists(filename + "_test") | assert os.path.exists(filename + "_test") | ||||
| os.remove("{}".format(filename)) | |||||
| os.remove("{}.db".format(filename)) | |||||
| os.remove("{}".format(filename + "_test")) | |||||
| os.remove("{}.db".format(filename + "_test")) | |||||
| def test_cifar100_to_mindrecord_directory(): | |||||
| def test_cifar100_to_mindrecord_directory(fixture_file): | |||||
| """ | """ | ||||
| test transform cifar100 dataset to mindrecord | test transform cifar100 dataset to mindrecord | ||||
| when destination path is directory. | when destination path is directory. | ||||
| @@ -129,8 +125,7 @@ def test_cifar100_to_mindrecord_directory(): | |||||
| CIFAR100_DIR) | CIFAR100_DIR) | ||||
| cifar100_transformer.transform() | cifar100_transformer.transform() | ||||
| def test_cifar100_to_mindrecord_filename_equals_cifar100(): | |||||
| def test_cifar100_to_mindrecord_filename_equals_cifar100(fixture_file): | |||||
| """ | """ | ||||
| test transform cifar100 dataset to mindrecord | test transform cifar100 dataset to mindrecord | ||||
| when destination path equals source path. | when destination path equals source path. | ||||
| @@ -24,36 +24,60 @@ from mindspore.mindrecord import MRMOpenError, SUCCESS | |||||
| CIFAR10_DIR = "../data/mindrecord/testCifar10Data" | CIFAR10_DIR = "../data/mindrecord/testCifar10Data" | ||||
| MINDRECORD_FILE = "./cifar10.mindrecord" | MINDRECORD_FILE = "./cifar10.mindrecord" | ||||
| def test_cifar10_to_mindrecord_without_index_fields(): | |||||
| @pytest.fixture | |||||
| def fixture_file(): | |||||
| """add/remove file""" | |||||
| def remove_file(x): | |||||
| if os.path.exists("{}".format(x)): | |||||
| os.remove("{}".format(x)) | |||||
| if os.path.exists("{}.db".format(x)): | |||||
| os.remove("{}.db".format(x)) | |||||
| if os.path.exists("{}_test".format(x)): | |||||
| os.remove("{}_test".format(x)) | |||||
| if os.path.exists("{}_test.db".format(x)): | |||||
| os.remove("{}_test.db".format(x)) | |||||
| remove_file(MINDRECORD_FILE) | |||||
| yield "yield_fixture_data" | |||||
| remove_file(MINDRECORD_FILE) | |||||
| @pytest.fixture | |||||
| def fixture_space_file(): | |||||
| """add/remove file""" | |||||
| def remove_file(x): | |||||
| if os.path.exists("{}".format(x)): | |||||
| os.remove("{}".format(x)) | |||||
| if os.path.exists("{}.db".format(x)): | |||||
| os.remove("{}.db".format(x)) | |||||
| if os.path.exists("{}_test".format(x)): | |||||
| os.remove("{}_test".format(x)) | |||||
| if os.path.exists("{}_test.db".format(x)): | |||||
| os.remove("{}_test.db".format(x)) | |||||
| x = "./yes ok" | |||||
| remove_file(x) | |||||
| yield "yield_fixture_data" | |||||
| remove_file(x) | |||||
| def test_cifar10_to_mindrecord_without_index_fields(fixture_file): | |||||
| """test transform cifar10 dataset to mindrecord without index fields.""" | """test transform cifar10 dataset to mindrecord without index fields.""" | ||||
| cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE) | cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE) | ||||
| cifar10_transformer.transform() | cifar10_transformer.transform() | ||||
| assert os.path.exists(MINDRECORD_FILE) | assert os.path.exists(MINDRECORD_FILE) | ||||
| assert os.path.exists(MINDRECORD_FILE + "_test") | assert os.path.exists(MINDRECORD_FILE + "_test") | ||||
| read() | read() | ||||
| os.remove("{}".format(MINDRECORD_FILE)) | |||||
| os.remove("{}.db".format(MINDRECORD_FILE)) | |||||
| os.remove("{}".format(MINDRECORD_FILE + "_test")) | |||||
| os.remove("{}.db".format(MINDRECORD_FILE + "_test")) | |||||
| def test_cifar10_to_mindrecord(): | |||||
| def test_cifar10_to_mindrecord(fixture_file): | |||||
| """test transform cifar10 dataset to mindrecord.""" | """test transform cifar10 dataset to mindrecord.""" | ||||
| cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE) | cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE) | ||||
| cifar10_transformer.transform(['label']) | cifar10_transformer.transform(['label']) | ||||
| assert os.path.exists(MINDRECORD_FILE) | assert os.path.exists(MINDRECORD_FILE) | ||||
| assert os.path.exists(MINDRECORD_FILE + "_test") | assert os.path.exists(MINDRECORD_FILE + "_test") | ||||
| read() | read() | ||||
| os.remove("{}".format(MINDRECORD_FILE)) | |||||
| os.remove("{}.db".format(MINDRECORD_FILE)) | |||||
| os.remove("{}".format(MINDRECORD_FILE + "_test")) | |||||
| os.remove("{}.db".format(MINDRECORD_FILE + "_test")) | |||||
| def test_cifar10_to_mindrecord_with_return(): | |||||
| def test_cifar10_to_mindrecord_with_return(fixture_file): | |||||
| """test transform cifar10 dataset to mindrecord.""" | """test transform cifar10 dataset to mindrecord.""" | ||||
| cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE) | cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE) | ||||
| ret = cifar10_transformer.transform(['label']) | ret = cifar10_transformer.transform(['label']) | ||||
| @@ -61,11 +85,6 @@ def test_cifar10_to_mindrecord_with_return(): | |||||
| assert os.path.exists(MINDRECORD_FILE) | assert os.path.exists(MINDRECORD_FILE) | ||||
| assert os.path.exists(MINDRECORD_FILE + "_test") | assert os.path.exists(MINDRECORD_FILE + "_test") | ||||
| read() | read() | ||||
| os.remove("{}".format(MINDRECORD_FILE)) | |||||
| os.remove("{}.db".format(MINDRECORD_FILE)) | |||||
| os.remove("{}".format(MINDRECORD_FILE + "_test")) | |||||
| os.remove("{}.db".format(MINDRECORD_FILE + "_test")) | |||||
| def read(): | def read(): | ||||
| @@ -90,8 +109,7 @@ def read(): | |||||
| assert count == 4 | assert count == 4 | ||||
| reader.close() | reader.close() | ||||
| def test_cifar10_to_mindrecord_illegal_file_name(): | |||||
| def test_cifar10_to_mindrecord_illegal_file_name(fixture_file): | |||||
| """ | """ | ||||
| test transform cifar10 dataset to mindrecord | test transform cifar10 dataset to mindrecord | ||||
| when file name contains illegal character. | when file name contains illegal character. | ||||
| @@ -101,8 +119,7 @@ def test_cifar10_to_mindrecord_illegal_file_name(): | |||||
| cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename) | cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename) | ||||
| cifar10_transformer.transform() | cifar10_transformer.transform() | ||||
| def test_cifar10_to_mindrecord_filename_start_with_space(): | |||||
| def test_cifar10_to_mindrecord_filename_start_with_space(fixture_file): | |||||
| """ | """ | ||||
| test transform cifar10 dataset to mindrecord | test transform cifar10 dataset to mindrecord | ||||
| when file name starts with space. | when file name starts with space. | ||||
| @@ -113,8 +130,7 @@ def test_cifar10_to_mindrecord_filename_start_with_space(): | |||||
| cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename) | cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename) | ||||
| cifar10_transformer.transform() | cifar10_transformer.transform() | ||||
| def test_cifar10_to_mindrecord_filename_contain_space(): | |||||
| def test_cifar10_to_mindrecord_filename_contain_space(fixture_space_file): | |||||
| """ | """ | ||||
| test transform cifar10 dataset to mindrecord | test transform cifar10 dataset to mindrecord | ||||
| when file name contains space. | when file name contains space. | ||||
| @@ -124,14 +140,8 @@ def test_cifar10_to_mindrecord_filename_contain_space(): | |||||
| cifar10_transformer.transform() | cifar10_transformer.transform() | ||||
| assert os.path.exists(filename) | assert os.path.exists(filename) | ||||
| assert os.path.exists(filename + "_test") | assert os.path.exists(filename + "_test") | ||||
| os.remove("{}".format(filename)) | |||||
| os.remove("{}.db".format(filename)) | |||||
| os.remove("{}".format(filename + "_test")) | |||||
| os.remove("{}.db".format(filename + "_test")) | |||||
| def test_cifar10_to_mindrecord_directory(): | |||||
| def test_cifar10_to_mindrecord_directory(fixture_file): | |||||
| """ | """ | ||||
| test transform cifar10 dataset to mindrecord | test transform cifar10 dataset to mindrecord | ||||
| when destination path is directory. | when destination path is directory. | ||||
| @@ -25,6 +25,26 @@ IMAGENET_IMAGE_DIR = "../data/mindrecord/testImageNetDataWhole/images" | |||||
| MINDRECORD_FILE = "../data/mindrecord/testImageNetDataWhole/imagenet.mindrecord" | MINDRECORD_FILE = "../data/mindrecord/testImageNetDataWhole/imagenet.mindrecord" | ||||
| PARTITION_NUMBER = 4 | PARTITION_NUMBER = 4 | ||||
| @pytest.fixture | |||||
| def fixture_file(): | |||||
| """add/remove file""" | |||||
| def remove_one_file(x): | |||||
| if os.path.exists(x): | |||||
| os.remove(x) | |||||
| def remove_file(): | |||||
| x = MINDRECORD_FILE | |||||
| remove_one_file(x) | |||||
| x = MINDRECORD_FILE + ".db" | |||||
| remove_one_file(x) | |||||
| for i in range(PARTITION_NUMBER): | |||||
| x = MINDRECORD_FILE + str(i) | |||||
| remove_one_file(x) | |||||
| x = MINDRECORD_FILE + str(i) + ".db" | |||||
| remove_one_file(x) | |||||
| remove_file() | |||||
| yield "yield_fixture_data" | |||||
| remove_file() | |||||
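For the ImageNet tests the writer runs with PARTITION_NUMBER shards, so the output is either a single imagenet.mindrecord plus its .db index, or numbered shards imagenet.mindrecord0..3 each with a per-shard .db file; the fixture above sweeps both forms explicitly. A glob-based sweep is one equivalent, more compact way to express the same cleanup (an alternative sketch, not the form used in this change):

import glob
import os

def remove_partitioned(prefix):
    # matches prefix, prefix.db, prefixN and prefixN.db for any shard number N
    for path in glob.glob(prefix + "*"):
        if os.path.isfile(path):
            os.remove(path)

# usage sketch with the file name from this test module
remove_partitioned("../data/mindrecord/testImageNetDataWhole/imagenet.mindrecord")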
| def read(filename): | def read(filename): | ||||
| """test file reade""" | """test file reade""" | ||||
| @@ -38,8 +58,7 @@ def read(filename): | |||||
| assert count == 20 | assert count == 20 | ||||
| reader.close() | reader.close() | ||||
| def test_imagenet_to_mindrecord(): | |||||
| def test_imagenet_to_mindrecord(fixture_file): | |||||
| """test transform imagenet dataset to mindrecord.""" | """test transform imagenet dataset to mindrecord.""" | ||||
| imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR, | imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR, | ||||
| MINDRECORD_FILE, PARTITION_NUMBER) | MINDRECORD_FILE, PARTITION_NUMBER) | ||||
| @@ -48,12 +67,8 @@ def test_imagenet_to_mindrecord(): | |||||
| assert os.path.exists(MINDRECORD_FILE + str(i)) | assert os.path.exists(MINDRECORD_FILE + str(i)) | ||||
| assert os.path.exists(MINDRECORD_FILE + str(i) + ".db") | assert os.path.exists(MINDRECORD_FILE + str(i) + ".db") | ||||
| read(MINDRECORD_FILE + "0") | read(MINDRECORD_FILE + "0") | ||||
| for i in range(PARTITION_NUMBER): | |||||
| os.remove(MINDRECORD_FILE + str(i)) | |||||
| os.remove(MINDRECORD_FILE + str(i) + ".db") | |||||
| def test_imagenet_to_mindrecord_default_partition_number(): | |||||
| def test_imagenet_to_mindrecord_default_partition_number(fixture_file): | |||||
| """ | """ | ||||
| test transform imagenet dataset to mindrecord | test transform imagenet dataset to mindrecord | ||||
| when partition number is default. | when partition number is default. | ||||
| @@ -64,11 +79,8 @@ def test_imagenet_to_mindrecord_default_partition_number(): | |||||
| assert os.path.exists(MINDRECORD_FILE) | assert os.path.exists(MINDRECORD_FILE) | ||||
| assert os.path.exists(MINDRECORD_FILE + ".db") | assert os.path.exists(MINDRECORD_FILE + ".db") | ||||
| read(MINDRECORD_FILE) | read(MINDRECORD_FILE) | ||||
| os.remove("{}".format(MINDRECORD_FILE)) | |||||
| os.remove("{}.db".format(MINDRECORD_FILE)) | |||||
| def test_imagenet_to_mindrecord_partition_number_0(): | |||||
| def test_imagenet_to_mindrecord_partition_number_0(fixture_file): | |||||
| """ | """ | ||||
| test transform imagenet dataset to mindrecord | test transform imagenet dataset to mindrecord | ||||
| when partition number is 0. | when partition number is 0. | ||||
| @@ -79,8 +91,7 @@ def test_imagenet_to_mindrecord_partition_number_0(): | |||||
| MINDRECORD_FILE, 0) | MINDRECORD_FILE, 0) | ||||
| imagenet_transformer.transform() | imagenet_transformer.transform() | ||||
| def test_imagenet_to_mindrecord_partition_number_none(): | |||||
| def test_imagenet_to_mindrecord_partition_number_none(fixture_file): | |||||
| """ | """ | ||||
| test transform imagenet dataset to mindrecord | test transform imagenet dataset to mindrecord | ||||
| when partition number is None. | when partition number is None. | ||||
| @@ -92,8 +103,7 @@ def test_imagenet_to_mindrecord_partition_number_none(): | |||||
| MINDRECORD_FILE, None) | MINDRECORD_FILE, None) | ||||
| imagenet_transformer.transform() | imagenet_transformer.transform() | ||||
| def test_imagenet_to_mindrecord_illegal_filename(): | |||||
| def test_imagenet_to_mindrecord_illegal_filename(fixture_file): | |||||
| """ | """ | ||||
| test transform imagenet dataset to mindrecord | test transform imagenet dataset to mindrecord | ||||
| when file name contains illegal character. | when file name contains illegal character. | ||||
| @@ -26,6 +26,34 @@ CV_FILE_NAME = "./imagenet.mindrecord" | |||||
| NLP_FILE_NAME = "./aclImdb.mindrecord" | NLP_FILE_NAME = "./aclImdb.mindrecord" | ||||
| FILES_NUM = 4 | FILES_NUM = 4 | ||||
| def remove_one_file(x): | |||||
| if os.path.exists(x): | |||||
| os.remove(x) | |||||
| def remove_file(file_name): | |||||
| x = file_name | |||||
| remove_one_file(x) | |||||
| x = file_name + ".db" | |||||
| remove_one_file(x) | |||||
| for i in range(FILES_NUM): | |||||
| x = file_name + str(i) | |||||
| remove_one_file(x) | |||||
| x = file_name + str(i) + ".db" | |||||
| remove_one_file(x) | |||||
| @pytest.fixture | |||||
| def fixture_cv_file(): | |||||
| """add/remove file""" | |||||
| remove_file(CV_FILE_NAME) | |||||
| yield "yield_fixture_data" | |||||
| remove_file(CV_FILE_NAME) | |||||
| @pytest.fixture | |||||
| def fixture_nlp_file(): | |||||
| """add/remove file""" | |||||
| remove_file(NLP_FILE_NAME) | |||||
| yield "yield_fixture_data" | |||||
| remove_file(NLP_FILE_NAME) | |||||
| def test_cv_file_writer_shard_num_none(): | def test_cv_file_writer_shard_num_none(): | ||||
| """test cv file writer when shard num is None.""" | """test cv file writer when shard num is None.""" | ||||
| @@ -83,8 +111,7 @@ def test_lack_partition_and_db(): | |||||
| 'error_msg: MindRecord File could not open successfully.' \ | 'error_msg: MindRecord File could not open successfully.' \ | ||||
| in str(err.value) | in str(err.value) | ||||
| def test_lack_db(): | |||||
| def test_lack_db(fixture_cv_file): | |||||
| """test file reader when db file does not exist.""" | """test file reader when db file does not exist.""" | ||||
| create_cv_mindrecord(1) | create_cv_mindrecord(1) | ||||
| os.remove("{}.db".format(CV_FILE_NAME)) | os.remove("{}.db".format(CV_FILE_NAME)) | ||||
| @@ -94,10 +121,8 @@ def test_lack_db(): | |||||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | assert '[MRMOpenError]: error_code: 1347690596, ' \ | ||||
| 'error_msg: MindRecord File could not open successfully.' \ | 'error_msg: MindRecord File could not open successfully.' \ | ||||
| in str(err.value) | in str(err.value) | ||||
| os.remove(CV_FILE_NAME) | |||||
| def test_lack_some_partition_and_db(): | |||||
| def test_lack_some_partition_and_db(fixture_cv_file): | |||||
| """test file reader when some partition and db do not exist.""" | """test file reader when some partition and db do not exist.""" | ||||
| create_cv_mindrecord(4) | create_cv_mindrecord(4) | ||||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | ||||
| @@ -110,16 +135,8 @@ def test_lack_some_partition_and_db(): | |||||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | assert '[MRMOpenError]: error_code: 1347690596, ' \ | ||||
| 'error_msg: MindRecord File could not open successfully.' \ | 'error_msg: MindRecord File could not open successfully.' \ | ||||
| in str(err.value) | in str(err.value) | ||||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||||
| for x in range(FILES_NUM)] | |||||
| for x in paths: | |||||
| if os.path.exists("{}".format(x)): | |||||
| os.remove("{}".format(x)) | |||||
| if os.path.exists("{}.db".format(x)): | |||||
| os.remove("{}.db".format(x)) | |||||
| def test_lack_some_partition_first(): | |||||
| def test_lack_some_partition_first(fixture_cv_file): | |||||
| """test file reader when first partition does not exist.""" | """test file reader when first partition does not exist.""" | ||||
| create_cv_mindrecord(4) | create_cv_mindrecord(4) | ||||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | ||||
| @@ -131,14 +148,8 @@ def test_lack_some_partition_first(): | |||||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | assert '[MRMOpenError]: error_code: 1347690596, ' \ | ||||
| 'error_msg: MindRecord File could not open successfully.' \ | 'error_msg: MindRecord File could not open successfully.' \ | ||||
| in str(err.value) | in str(err.value) | ||||
| for x in paths: | |||||
| if os.path.exists("{}".format(x)): | |||||
| os.remove("{}".format(x)) | |||||
| if os.path.exists("{}.db".format(x)): | |||||
| os.remove("{}.db".format(x)) | |||||
| def test_lack_some_partition_middle(): | |||||
| def test_lack_some_partition_middle(fixture_cv_file): | |||||
| """test file reader when some partition does not exist.""" | """test file reader when some partition does not exist.""" | ||||
| create_cv_mindrecord(4) | create_cv_mindrecord(4) | ||||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | ||||
| @@ -150,14 +161,8 @@ def test_lack_some_partition_middle(): | |||||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | assert '[MRMOpenError]: error_code: 1347690596, ' \ | ||||
| 'error_msg: MindRecord File could not open successfully.' \ | 'error_msg: MindRecord File could not open successfully.' \ | ||||
| in str(err.value) | in str(err.value) | ||||
| for x in paths: | |||||
| if os.path.exists("{}".format(x)): | |||||
| os.remove("{}".format(x)) | |||||
| if os.path.exists("{}.db".format(x)): | |||||
| os.remove("{}.db".format(x)) | |||||
| def test_lack_some_partition_last(): | |||||
| def test_lack_some_partition_last(fixture_cv_file): | |||||
| """test file reader when last partition does not exist.""" | """test file reader when last partition does not exist.""" | ||||
| create_cv_mindrecord(4) | create_cv_mindrecord(4) | ||||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | ||||
| @@ -169,14 +174,8 @@ def test_lack_some_partition_last(): | |||||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | assert '[MRMOpenError]: error_code: 1347690596, ' \ | ||||
| 'error_msg: MindRecord File could not open successfully.' \ | 'error_msg: MindRecord File could not open successfully.' \ | ||||
| in str(err.value) | in str(err.value) | ||||
| for x in paths: | |||||
| if os.path.exists("{}".format(x)): | |||||
| os.remove("{}".format(x)) | |||||
| if os.path.exists("{}.db".format(x)): | |||||
| os.remove("{}.db".format(x)) | |||||
| def test_mindpage_lack_some_partition(): | |||||
| def test_mindpage_lack_some_partition(fixture_cv_file): | |||||
| """test page reader when some partition does not exist.""" | """test page reader when some partition does not exist.""" | ||||
| create_cv_mindrecord(4) | create_cv_mindrecord(4) | ||||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | ||||
| @@ -187,14 +186,8 @@ def test_mindpage_lack_some_partition(): | |||||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | assert '[MRMOpenError]: error_code: 1347690596, ' \ | ||||
| 'error_msg: MindRecord File could not open successfully.' \ | 'error_msg: MindRecord File could not open successfully.' \ | ||||
| in str(err.value) | in str(err.value) | ||||
| for x in paths: | |||||
| if os.path.exists("{}".format(x)): | |||||
| os.remove("{}".format(x)) | |||||
| if os.path.exists("{}.db".format(x)): | |||||
| os.remove("{}.db".format(x)) | |||||
| def test_lack_some_db(): | |||||
| def test_lack_some_db(fixture_cv_file): | |||||
| """test file reader when some db does not exist.""" | """test file reader when some db does not exist.""" | ||||
| create_cv_mindrecord(4) | create_cv_mindrecord(4) | ||||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | ||||
| @@ -206,11 +199,6 @@ def test_lack_some_db(): | |||||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | assert '[MRMOpenError]: error_code: 1347690596, ' \ | ||||
| 'error_msg: MindRecord File could not open successfully.' \ | 'error_msg: MindRecord File could not open successfully.' \ | ||||
| in str(err.value) | in str(err.value) | ||||
| for x in paths: | |||||
| if os.path.exists("{}".format(x)): | |||||
| os.remove("{}".format(x)) | |||||
| if os.path.exists("{}.db".format(x)): | |||||
| os.remove("{}.db".format(x)) | |||||
| def test_invalid_mindrecord(): | def test_invalid_mindrecord(): | ||||
| @@ -225,8 +213,7 @@ def test_invalid_mindrecord(): | |||||
| in str(err.value) | in str(err.value) | ||||
| os.remove(CV_FILE_NAME) | os.remove(CV_FILE_NAME) | ||||
| def test_invalid_db(): | |||||
| def test_invalid_db(fixture_cv_file): | |||||
| """test file reader when the content of db is illegal.""" | """test file reader when the content of db is illegal.""" | ||||
| create_cv_mindrecord(1) | create_cv_mindrecord(1) | ||||
| os.remove("imagenet.mindrecord.db") | os.remove("imagenet.mindrecord.db") | ||||
| @@ -237,11 +224,8 @@ def test_invalid_db(): | |||||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | assert '[MRMOpenError]: error_code: 1347690596, ' \ | ||||
| 'error_msg: MindRecord File could not open successfully.' \ | 'error_msg: MindRecord File could not open successfully.' \ | ||||
| in str(err.value) | in str(err.value) | ||||
| os.remove("imagenet.mindrecord") | |||||
| os.remove("imagenet.mindrecord.db") | |||||
| def test_overwrite_invalid_mindrecord(): | |||||
| def test_overwrite_invalid_mindrecord(fixture_cv_file): | |||||
| """test file writer when overwrite invalid mindreocrd file.""" | """test file writer when overwrite invalid mindreocrd file.""" | ||||
| with open(CV_FILE_NAME, 'w') as f: | with open(CV_FILE_NAME, 'w') as f: | ||||
| f.write('just for test') | f.write('just for test') | ||||
| @@ -250,10 +234,8 @@ def test_overwrite_invalid_mindrecord(): | |||||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | assert '[MRMOpenError]: error_code: 1347690596, ' \ | ||||
| 'error_msg: MindRecord File could not open successfully.' \ | 'error_msg: MindRecord File could not open successfully.' \ | ||||
| in str(err.value) | in str(err.value) | ||||
| os.remove(CV_FILE_NAME) | |||||
| def test_overwrite_invalid_db(): | |||||
| def test_overwrite_invalid_db(fixture_cv_file): | |||||
| """test file writer when overwrite invalid db file.""" | """test file writer when overwrite invalid db file.""" | ||||
| with open('imagenet.mindrecord.db', 'w') as f: | with open('imagenet.mindrecord.db', 'w') as f: | ||||
| f.write('just for test') | f.write('just for test') | ||||
| @@ -261,11 +243,8 @@ def test_overwrite_invalid_db(): | |||||
| create_cv_mindrecord(1) | create_cv_mindrecord(1) | ||||
| assert '[MRMGenerateIndexError]: error_code: 1347690612, ' \ | assert '[MRMGenerateIndexError]: error_code: 1347690612, ' \ | ||||
| 'error_msg: Failed to generate index.' in str(err.value) | 'error_msg: Failed to generate index.' in str(err.value) | ||||
| os.remove("imagenet.mindrecord") | |||||
| os.remove("imagenet.mindrecord.db") | |||||
| def test_read_after_close(): | |||||
| def test_read_after_close(fixture_cv_file): | |||||
| """test file reader when close read.""" | """test file reader when close read.""" | ||||
| create_cv_mindrecord(1) | create_cv_mindrecord(1) | ||||
| reader = FileReader(CV_FILE_NAME) | reader = FileReader(CV_FILE_NAME) | ||||
| @@ -275,11 +254,8 @@ def test_read_after_close(): | |||||
| count = count + 1 | count = count + 1 | ||||
| logger.info("#item{}: {}".format(index, x)) | logger.info("#item{}: {}".format(index, x)) | ||||
| assert count == 0 | assert count == 0 | ||||
| os.remove(CV_FILE_NAME) | |||||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||||
| def test_file_read_after_read(): | |||||
| def test_file_read_after_read(fixture_cv_file): | |||||
| """test file reader when finish read.""" | """test file reader when finish read.""" | ||||
| create_cv_mindrecord(1) | create_cv_mindrecord(1) | ||||
| reader = FileReader(CV_FILE_NAME) | reader = FileReader(CV_FILE_NAME) | ||||
| @@ -295,8 +271,6 @@ def test_file_read_after_read(): | |||||
| cnt = cnt + 1 | cnt = cnt + 1 | ||||
| logger.info("#item{}: {}".format(index, x)) | logger.info("#item{}: {}".format(index, x)) | ||||
| assert cnt == 0 | assert cnt == 0 | ||||
| os.remove(CV_FILE_NAME) | |||||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||||
| def test_cv_file_writer_shard_num_greater_than_1000(): | def test_cv_file_writer_shard_num_greater_than_1000(): | ||||
| @@ -312,8 +286,7 @@ def test_add_index_without_add_schema(): | |||||
| fw.add_index(["label"]) | fw.add_index(["label"]) | ||||
| assert 'Failed to get meta info' in str(err.value) | assert 'Failed to get meta info' in str(err.value) | ||||
| def test_mindpage_pageno_pagesize_not_int(): | |||||
| def test_mindpage_pageno_pagesize_not_int(fixture_cv_file): | |||||
| """test page reader when some partition does not exist.""" | """test page reader when some partition does not exist.""" | ||||
| create_cv_mindrecord(4) | create_cv_mindrecord(4) | ||||
| reader = MindPage(CV_FILE_NAME + "0") | reader = MindPage(CV_FILE_NAME + "0") | ||||
| @@ -342,14 +315,8 @@ def test_mindpage_pageno_pagesize_not_int(): | |||||
| with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."): | with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."): | ||||
| reader.read_at_page_by_id(99999, 0, 1) | reader.read_at_page_by_id(99999, 0, 1) | ||||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||||
| for x in range(FILES_NUM)] | |||||
| for x in paths: | |||||
| os.remove("{}".format(x)) | |||||
| os.remove("{}.db".format(x)) | |||||
| def test_mindpage_filename_not_exist(): | |||||
| def test_mindpage_filename_not_exist(fixture_cv_file): | |||||
| """test page reader when some partition does not exist.""" | """test page reader when some partition does not exist.""" | ||||
| create_cv_mindrecord(4) | create_cv_mindrecord(4) | ||||
| reader = MindPage(CV_FILE_NAME + "0") | reader = MindPage(CV_FILE_NAME + "0") | ||||
| @@ -374,6 +341,3 @@ def test_mindpage_filename_not_exist(): | |||||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | ||||
| for x in range(FILES_NUM)] | for x in range(FILES_NUM)] | ||||
| for x in paths: | |||||
| os.remove("{}".format(x)) | |||||
| os.remove("{}.db".format(x)) | |||||
| @@ -14,6 +14,7 @@ | |||||
| """test mnist to mindrecord tool""" | """test mnist to mindrecord tool""" | ||||
| import cv2 | import cv2 | ||||
| import gzip | import gzip | ||||
| import pytest | |||||
| import numpy as np | import numpy as np | ||||
| import os | import os | ||||
| @@ -27,6 +28,34 @@ PARTITION_NUM = 4 | |||||
| IMAGE_SIZE = 28 | IMAGE_SIZE = 28 | ||||
| NUM_CHANNELS = 1 | NUM_CHANNELS = 1 | ||||
| @pytest.fixture | |||||
| def fixture_file(): | |||||
| """add/remove file""" | |||||
| def remove_one_file(x): | |||||
| if os.path.exists(x): | |||||
| os.remove(x) | |||||
| def remove_file(): | |||||
| x = "mnist_train.mindrecord" | |||||
| remove_one_file(x) | |||||
| x = "mnist_train.mindrecord.db" | |||||
| remove_one_file(x) | |||||
| x = "mnist_test.mindrecord" | |||||
| remove_one_file(x) | |||||
| x = "mnist_test.mindrecord.db" | |||||
| remove_one_file(x) | |||||
| for i in range(PARTITION_NUM): | |||||
| x = "mnist_train.mindrecord" + str(i) | |||||
| remove_one_file(x) | |||||
| x = "mnist_train.mindrecord" + str(i) + ".db" | |||||
| remove_one_file(x) | |||||
| x = "mnist_test.mindrecord" + str(i) | |||||
| remove_one_file(x) | |||||
| x = "mnist_test.mindrecord" + str(i) + ".db" | |||||
| remove_one_file(x) | |||||
| remove_file() | |||||
| yield "yield_fixture_data" | |||||
| remove_file() | |||||
| def read(train_name, test_name): | def read(train_name, test_name): | ||||
| """test file reader""" | """test file reader""" | ||||
| @@ -51,7 +80,7 @@ def read(train_name, test_name): | |||||
| reader.close() | reader.close() | ||||
| def test_mnist_to_mindrecord(): | |||||
| def test_mnist_to_mindrecord(fixture_file): | |||||
| """test transform mnist dataset to mindrecord.""" | """test transform mnist dataset to mindrecord.""" | ||||
| mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME) | mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME) | ||||
| mnist_transformer.transform() | mnist_transformer.transform() | ||||
| @@ -60,13 +89,7 @@ def test_mnist_to_mindrecord(): | |||||
| read("mnist_train.mindrecord", "mnist_test.mindrecord") | read("mnist_train.mindrecord", "mnist_test.mindrecord") | ||||
| os.remove("{}".format("mnist_train.mindrecord")) | |||||
| os.remove("{}.db".format("mnist_train.mindrecord")) | |||||
| os.remove("{}".format("mnist_test.mindrecord")) | |||||
| os.remove("{}.db".format("mnist_test.mindrecord")) | |||||
| def test_mnist_to_mindrecord_compare_data(): | |||||
| def test_mnist_to_mindrecord_compare_data(fixture_file): | |||||
| """test transform mnist dataset to mindrecord and compare data.""" | """test transform mnist dataset to mindrecord and compare data.""" | ||||
| mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME) | mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME) | ||||
| mnist_transformer.transform() | mnist_transformer.transform() | ||||
| @@ -121,21 +144,10 @@ def test_mnist_to_mindrecord_compare_data(): | |||||
| assert np.array(x['label']) == label | assert np.array(x['label']) == label | ||||
| reader.close() | reader.close() | ||||
| os.remove("{}".format("mnist_train.mindrecord")) | |||||
| os.remove("{}.db".format("mnist_train.mindrecord")) | |||||
| os.remove("{}".format("mnist_test.mindrecord")) | |||||
| os.remove("{}.db".format("mnist_test.mindrecord")) | |||||
| def test_mnist_to_mindrecord_multi_partition(): | |||||
| def test_mnist_to_mindrecord_multi_partition(fixture_file): | |||||
| """test transform mnist dataset to multiple mindrecord files.""" | """test transform mnist dataset to multiple mindrecord files.""" | ||||
| mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME, PARTITION_NUM) | mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME, PARTITION_NUM) | ||||
| mnist_transformer.transform() | mnist_transformer.transform() | ||||
| read("mnist_train.mindrecord0", "mnist_test.mindrecord0") | read("mnist_train.mindrecord0", "mnist_test.mindrecord0") | ||||
| for i in range(PARTITION_NUM): | |||||
| os.remove("{}".format("mnist_train.mindrecord" + str(i))) | |||||
| os.remove("{}.db".format("mnist_train.mindrecord" + str(i))) | |||||
| os.remove("{}".format("mnist_test.mindrecord" + str(i))) | |||||
| os.remove("{}.db".format("mnist_test.mindrecord" + str(i))) | |||||