Merge pull request !1317 from liyong126/mindrecord_compress (tag: v0.3.0-alpha)
| @@ -112,25 +112,26 @@ Status MindRecordOp::Init() { | |||
| data_schema_ = std::make_unique<DataSchema>(); | |||
| std::vector<std::shared_ptr<Schema>> schema_vec = shard_reader_->GetShardHeader()->GetSchemas(); | |||
| // check whether schema exists, if so use the first one | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!schema_vec.empty(), "No schema found"); | |||
| mindrecord::json mr_schema = schema_vec[0]->GetSchema()["schema"]; | |||
| std::vector<std::string> col_names = shard_reader_->get_shard_column()->GetColumnName(); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!col_names.empty(), "No column names found"); | |||
| std::vector<mindrecord::ColumnDataType> col_data_types = shard_reader_->get_shard_column()->GeColumnDataType(); | |||
| std::vector<std::vector<int64_t>> col_shapes = shard_reader_->get_shard_column()->GetColumnShape(); | |||
| bool load_all_cols = columns_to_load_.empty(); // if columns_to_load_ is empty it means load everything | |||
| std::map<std::string, int32_t> colname_to_ind; | |||
| for (mindrecord::json::iterator it = mr_schema.begin(); it != mr_schema.end(); ++it) { | |||
| std::string colname = it.key(); // key of the json, column name | |||
| mindrecord::json it_value = it.value(); // value, which contains type info and may contain shape | |||
| for (uint32_t i = 0; i < col_names.size(); i++) { | |||
| std::string colname = col_names[i]; | |||
| ColDescriptor col_desc; | |||
| TensorShape t_shape = TensorShape::CreateUnknownRankShape(); // shape of tensor, default unknown | |||
| std::string type_str = (it_value["type"] == "bytes" || it_value["type"] == "string") ? "uint8" : it_value["type"]; | |||
| std::string type_str = mindrecord::ColumnDataTypeNameNormalized[col_data_types[i]]; | |||
| DataType t_dtype = DataType(type_str); // valid types: {"bytes", "string", "int32", "int64", "float32", "float64"} | |||
| if (it_value["type"] == "bytes") { // rank = 1 | |||
| if (col_data_types[i] == mindrecord::ColumnBytes || col_data_types[i] == mindrecord::ColumnString) { // rank = 1 | |||
| col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, 1); | |||
| } else if (it_value.find("shape") != it_value.end()) { | |||
| std::vector<dsize_t> vec(it_value["shape"].size()); // temporary vector to hold shape | |||
| (void)std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin()); | |||
| } else if (col_shapes[i].size() > 0) { | |||
| std::vector<dsize_t> vec(col_shapes[i].size()); // temporary vector to hold shape | |||
| (void)std::copy(col_shapes[i].begin(), col_shapes[i].end(), vec.begin()); | |||
| t_shape = TensorShape(vec); | |||
| col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, t_shape.Rank(), &t_shape); | |||
| } else { // unknown shape | |||
| @@ -162,30 +163,7 @@ Status MindRecordOp::Init() { | |||
| num_rows_ = shard_reader_->GetNumRows(); | |||
| // Compute how many buffers we would need to accomplish rowsPerBuffer | |||
| buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_; | |||
| RETURN_IF_NOT_OK(SetColumnsBlob()); | |||
| return Status::OK(); | |||
| } | |||
| Status MindRecordOp::SetColumnsBlob() { | |||
| columns_blob_ = shard_reader_->GetBlobFields().second; | |||
| // get the exactly blob fields by columns_to_load_ | |||
| std::vector<std::string> columns_blob_exact; | |||
| for (auto &blob_field : columns_blob_) { | |||
| for (auto &column : columns_to_load_) { | |||
| if (column.compare(blob_field) == 0) { | |||
| columns_blob_exact.push_back(blob_field); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| columns_blob_index_ = std::vector<int32_t>(columns_to_load_.size(), -1); | |||
| int32_t iBlob = 0; | |||
| for (auto &blob_exact : columns_blob_exact) { | |||
| columns_blob_index_[column_name_id_map_[blob_exact]] = iBlob++; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -215,248 +193,18 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const { | |||
| } | |||
| } | |||
| template <typename T> | |||
| Status MindRecordOp::LoadFeature(std::shared_ptr<Tensor> *tensor, int32_t i_col, | |||
| const std::vector<uint8_t> &columns_blob, const mindrecord::json &columns_json) const { | |||
| TensorShape new_shape = TensorShape::CreateUnknownRankShape(); | |||
| const unsigned char *data = nullptr; | |||
| std::unique_ptr<T[]> array_data; | |||
| std::string string_data; | |||
| const ColDescriptor &cur_column = data_schema_->column(i_col); | |||
| std::string column_name = columns_to_load_[i_col]; | |||
| DataType type = cur_column.type(); | |||
| // load blob column | |||
| if (columns_blob_index_[i_col] >= 0 && columns_blob.size() > 0) { | |||
| int32_t pos = columns_blob_.size() == 1 ? -1 : columns_blob_index_[i_col]; | |||
| RETURN_IF_NOT_OK(LoadBlob(&new_shape, &data, columns_blob, pos, cur_column)); | |||
| } else { | |||
| switch (type.value()) { | |||
| case DataType::DE_UINT8: { | |||
| // For strings (Assume DE_UINT8 is reserved for strings) | |||
| RETURN_IF_NOT_OK(LoadByte(&new_shape, &string_data, column_name, columns_json)); | |||
| data = reinterpret_cast<const unsigned char *>(common::SafeCStr(string_data)); | |||
| break; | |||
| } | |||
| case DataType::DE_FLOAT32: { | |||
| // For both float scalars and arrays | |||
| RETURN_IF_NOT_OK(LoadFloat(&new_shape, &array_data, column_name, columns_json, cur_column, false)); | |||
| data = reinterpret_cast<const unsigned char *>(array_data.get()); | |||
| break; | |||
| } | |||
| case DataType::DE_FLOAT64: { | |||
| // For both double scalars and arrays | |||
| RETURN_IF_NOT_OK(LoadFloat(&new_shape, &array_data, column_name, columns_json, cur_column, true)); | |||
| data = reinterpret_cast<const unsigned char *>(array_data.get()); | |||
| break; | |||
| } | |||
| default: { | |||
| // For both integers scalars and arrays | |||
| RETURN_IF_NOT_OK(LoadInt(&new_shape, &array_data, column_name, columns_json, cur_column)); | |||
| data = reinterpret_cast<const unsigned char *>(array_data.get()); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| // Create Tensor with given details | |||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, cur_column.tensorImpl(), new_shape, type, data)); | |||
| return Status::OK(); | |||
| } | |||
| Status MindRecordOp::LoadBlob(TensorShape *new_shape, const unsigned char **data, | |||
| const std::vector<uint8_t> &columns_blob, const int32_t pos, | |||
| const ColDescriptor &column) { | |||
| const auto kColumnSize = column.type().SizeInBytes(); | |||
| if (kColumnSize == 0) { | |||
| RETURN_STATUS_UNEXPECTED("column size is null"); | |||
| } | |||
| if (pos == -1) { | |||
| if (column.hasShape()) { | |||
| *new_shape = TensorShape::CreateUnknownRankShape(); | |||
| RETURN_IF_NOT_OK( | |||
| column.MaterializeTensorShape(static_cast<int32_t>(columns_blob.size() / kColumnSize), new_shape)); | |||
| } else { | |||
| std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(columns_blob.size() / kColumnSize)}; | |||
| *new_shape = TensorShape(shapeDetails); | |||
| } | |||
| *data = reinterpret_cast<const uint8_t *>(&(columns_blob[0])); | |||
| return Status::OK(); | |||
| } | |||
| auto uint64_from_bytes = [&](int64_t pos) { | |||
| uint64_t result = 0; | |||
| for (uint64_t n = 0; n < kInt64Len; n++) { | |||
| result = (result << 8) + columns_blob[pos + n]; | |||
| } | |||
| return result; | |||
| }; | |||
| uint64_t iStart = 0; | |||
| for (int32_t i = 0; i < pos; i++) { | |||
| uint64_t num_bytes = uint64_from_bytes(iStart); | |||
| iStart += kInt64Len + num_bytes; | |||
| } | |||
| uint64_t num_bytes = uint64_from_bytes(iStart); | |||
| iStart += kInt64Len; | |||
| if (column.hasShape()) { | |||
| *new_shape = TensorShape::CreateUnknownRankShape(); | |||
| RETURN_IF_NOT_OK(column.MaterializeTensorShape(static_cast<int32_t>(num_bytes / kColumnSize), new_shape)); | |||
| } else { | |||
| std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(num_bytes / kColumnSize)}; | |||
| *new_shape = TensorShape(shapeDetails); | |||
| } | |||
| *data = reinterpret_cast<const uint8_t *>(&(columns_blob[iStart])); | |||
| return Status::OK(); | |||
| } | |||
| template <typename T> | |||
| Status MindRecordOp::LoadFloat(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name, | |||
| const mindrecord::json &columns_json, const ColDescriptor &column, bool use_double) { | |||
| if (!columns_json[column_name].is_array()) { | |||
| T value = 0; | |||
| RETURN_IF_NOT_OK(GetFloat(&value, columns_json[column_name], use_double)); | |||
| *new_shape = TensorShape::CreateScalar(); | |||
| *array_data = std::make_unique<T[]>(1); | |||
| (*array_data)[0] = value; | |||
| } else { | |||
| if (column.hasShape()) { | |||
| *new_shape = TensorShape(column.shape()); | |||
| } else { | |||
| std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(columns_json[column_name].size())}; | |||
| *new_shape = TensorShape(shapeDetails); | |||
| } | |||
| int idx = 0; | |||
| *array_data = std::make_unique<T[]>(new_shape->NumOfElements()); | |||
| for (auto &element : columns_json[column_name]) { | |||
| T value = 0; | |||
| RETURN_IF_NOT_OK(GetFloat(&value, element, use_double)); | |||
| (*array_data)[idx++] = value; | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| template <typename T> | |||
| Status MindRecordOp::GetFloat(T *value, const mindrecord::json &data, bool use_double) { | |||
| if (data.is_number()) { | |||
| *value = data; | |||
| } else if (data.is_string()) { | |||
| try { | |||
| if (use_double) { | |||
| *value = data.get<double>(); | |||
| } else { | |||
| *value = data.get<float>(); | |||
| } | |||
| } catch (mindrecord::json::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED("Conversion to float failed."); | |||
| } | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Conversion to float failed."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| template <typename T> | |||
| Status MindRecordOp::LoadInt(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name, | |||
| const mindrecord::json &columns_json, const ColDescriptor &column) { | |||
| if (!columns_json[column_name].is_array()) { | |||
| T value = 0; | |||
| RETURN_IF_NOT_OK(GetInt(&value, columns_json[column_name])); | |||
| *new_shape = TensorShape::CreateScalar(); | |||
| *array_data = std::make_unique<T[]>(1); | |||
| (*array_data)[0] = value; | |||
| } else { | |||
| if (column.hasShape()) { | |||
| *new_shape = TensorShape(column.shape()); | |||
| } else { | |||
| std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(columns_json[column_name].size())}; | |||
| *new_shape = TensorShape(shapeDetails); | |||
| } | |||
| int idx = 0; | |||
| *array_data = std::make_unique<T[]>(new_shape->NumOfElements()); | |||
| for (auto &element : columns_json[column_name]) { | |||
| T value = 0; | |||
| RETURN_IF_NOT_OK(GetInt(&value, element)); | |||
| (*array_data)[idx++] = value; | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| template <typename T> | |||
| Status MindRecordOp::GetInt(T *value, const mindrecord::json &data) { | |||
| int64_t temp_value = 0; | |||
| bool less_than_zero = false; | |||
| if (data.is_number_integer()) { | |||
| const mindrecord::json json_zero = 0; | |||
| if (data < json_zero) less_than_zero = true; | |||
| temp_value = data; | |||
| } else if (data.is_string()) { | |||
| std::string string_value = data; | |||
| if (!string_value.empty() && string_value[0] == '-') { | |||
| try { | |||
| temp_value = std::stoll(string_value); | |||
| less_than_zero = true; | |||
| } catch (std::invalid_argument &e) { | |||
| RETURN_STATUS_UNEXPECTED("Conversion to int failed, invalid argument."); | |||
| } catch (std::out_of_range &e) { | |||
| RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range."); | |||
| } | |||
| } else { | |||
| try { | |||
| temp_value = static_cast<int64_t>(std::stoull(string_value)); | |||
| } catch (std::invalid_argument &e) { | |||
| RETURN_STATUS_UNEXPECTED("Conversion to int failed, invalid argument."); | |||
| } catch (std::out_of_range &e) { | |||
| RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range."); | |||
| } | |||
| } | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Conversion to int failed."); | |||
| } | |||
| if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) || | |||
| (!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) { | |||
| RETURN_STATUS_UNEXPECTED("Conversion to int failed. Out of range"); | |||
| } | |||
| *value = static_cast<T>(temp_value); | |||
| return Status::OK(); | |||
| } | |||
| Status MindRecordOp::LoadByte(TensorShape *new_shape, std::string *string_data, const std::string &column_name, | |||
| const mindrecord::json &columns_json) { | |||
| *string_data = columns_json[column_name]; | |||
| std::vector<dsize_t> shape_details = {static_cast<dsize_t>(string_data->size())}; | |||
| *new_shape = TensorShape(shape_details); | |||
| return Status::OK(); | |||
| } | |||
| Status MindRecordOp::WorkerEntry(int32_t worker_id) { | |||
| TaskManager::FindMe()->Post(); | |||
| std::unique_ptr<IOBlock> io_block; | |||
| RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); | |||
| while (io_block != nullptr) { | |||
| if (io_block->eoe() == true) { | |||
| if (io_block->eoe()) { | |||
| RETURN_IF_NOT_OK( | |||
| out_connector_->Add(worker_id, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)))); | |||
| RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); | |||
| continue; | |||
| } | |||
| if (io_block->eof() == true) { | |||
| if (io_block->eof()) { | |||
| RETURN_IF_NOT_OK( | |||
| out_connector_->Add(worker_id, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)))); | |||
| RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); | |||
| @@ -521,19 +269,10 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_bu | |||
| if (tupled_buffer.empty()) break; | |||
| } | |||
| for (const auto &tupled_row : tupled_buffer) { | |||
| std::vector<uint8_t> columnsBlob = std::get<0>(tupled_row); | |||
| std::vector<uint8_t> columns_blob = std::get<0>(tupled_row); | |||
| mindrecord::json columns_json = std::get<1>(tupled_row); | |||
| TensorRow tensor_row; | |||
| for (uint32_t j = 0; j < columns_to_load_.size(); ++j) { | |||
| std::shared_ptr<Tensor> tensor; | |||
| const ColDescriptor &cur_column = data_schema_->column(j); | |||
| DataType type = cur_column.type(); | |||
| RETURN_IF_NOT_OK(SwitchLoadFeature(type, &tensor, j, columnsBlob, columns_json)); | |||
| tensor_row.push_back(std::move(tensor)); | |||
| } | |||
| RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, columns_blob, columns_json)); | |||
| tensor_table->push_back(std::move(tensor_row)); | |||
| } | |||
| } | |||
| @@ -543,48 +282,46 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_bu | |||
| return Status::OK(); | |||
| } | |||
| Status MindRecordOp::SwitchLoadFeature(const DataType &type, std::shared_ptr<Tensor> *tensor, int32_t i_col, | |||
| const std::vector<uint8_t> &columns_blob, | |||
| const mindrecord::json &columns_json) const { | |||
| switch (type.value()) { | |||
| case DataType::DE_BOOL: { | |||
| return LoadFeature<bool>(tensor, i_col, columns_blob, columns_json); | |||
| } | |||
| case DataType::DE_INT8: { | |||
| return LoadFeature<int8_t>(tensor, i_col, columns_blob, columns_json); | |||
| } | |||
| case DataType::DE_UINT8: { | |||
| return LoadFeature<uint8_t>(tensor, i_col, columns_blob, columns_json); | |||
| } | |||
| case DataType::DE_INT16: { | |||
| return LoadFeature<int16_t>(tensor, i_col, columns_blob, columns_json); | |||
| } | |||
| case DataType::DE_UINT16: { | |||
| return LoadFeature<uint16_t>(tensor, i_col, columns_blob, columns_json); | |||
| } | |||
| case DataType::DE_INT32: { | |||
| return LoadFeature<int32_t>(tensor, i_col, columns_blob, columns_json); | |||
| } | |||
| case DataType::DE_UINT32: { | |||
| return LoadFeature<uint32_t>(tensor, i_col, columns_blob, columns_json); | |||
| } | |||
| case DataType::DE_INT64: { | |||
| return LoadFeature<int64_t>(tensor, i_col, columns_blob, columns_json); | |||
| } | |||
| case DataType::DE_UINT64: { | |||
| return LoadFeature<uint64_t>(tensor, i_col, columns_blob, columns_json); | |||
| } | |||
| case DataType::DE_FLOAT32: { | |||
| return LoadFeature<float>(tensor, i_col, columns_blob, columns_json); | |||
| } | |||
| case DataType::DE_FLOAT64: { | |||
| return LoadFeature<double>(tensor, i_col, columns_blob, columns_json); | |||
| Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint8_t> &columns_blob, | |||
| const mindrecord::json &columns_json) { | |||
| for (uint32_t i_col = 0; i_col < columns_to_load_.size(); i_col++) { | |||
| auto column_name = columns_to_load_[i_col]; | |||
| // Initialize column parameters | |||
| const unsigned char *data = nullptr; | |||
| std::unique_ptr<unsigned char[]> data_ptr; | |||
| uint64_t n_bytes = 0; | |||
| mindrecord::ColumnDataType column_data_type = mindrecord::ColumnNoDataType; | |||
| uint64_t column_data_type_size = 1; | |||
| std::vector<int64_t> column_shape; | |||
| // Get column data | |||
| auto has_column = shard_reader_->get_shard_column()->GetColumnValueByName( | |||
| column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, &column_data_type, &column_data_type_size, | |||
| &column_shape); | |||
| if (has_column == MSRStatus::FAILED) { | |||
| RETURN_STATUS_UNEXPECTED("Failed to retrieve data from mindrecord reader."); | |||
| } | |||
| default: { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, | |||
| "mindrecord column list type does not match any known types"); | |||
| std::shared_ptr<Tensor> tensor; | |||
| const ColDescriptor &column = data_schema_->column(i_col); | |||
| DataType type = column.type(); | |||
| // Set shape | |||
| auto num_elements = n_bytes / column_data_type_size; | |||
| if (column.hasShape()) { | |||
| auto new_shape = TensorShape(column.shape()); | |||
| RETURN_IF_NOT_OK(column.MaterializeTensorShape(static_cast<int32_t>(num_elements), &new_shape)); | |||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data)); | |||
| } else { | |||
| std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(num_elements)}; | |||
| auto new_shape = TensorShape(shapeDetails); | |||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data)); | |||
| } | |||
| tensor_row->push_back(std::move(tensor)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) { | |||
| @@ -23,6 +23,7 @@ | |||
| #include <queue> | |||
| #include <string> | |||
| #include <tuple> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include <vector> | |||
| @@ -31,6 +32,7 @@ | |||
| #include "dataset/engine/datasetops/source/io_block.h" | |||
| #include "dataset/util/queue.h" | |||
| #include "dataset/util/status.h" | |||
| #include "mindrecord/include/shard_column.h" | |||
| #include "mindrecord/include/shard_error.h" | |||
| #include "mindrecord/include/shard_reader.h" | |||
| #include "mindrecord/include/common/shard_utils.h" | |||
| @@ -193,8 +195,6 @@ class MindRecordOp : public ParallelOp { | |||
| Status Init(); | |||
| Status SetColumnsBlob(); | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| @@ -205,56 +205,11 @@ class MindRecordOp : public ParallelOp { | |||
| Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id); | |||
| // Parses all columns of a single row and puts the data into a tensor row | |||
| // @param tensor - the tensor to put the parsed data in | |||
| // @param i_col - the id of column to parse | |||
| // @param tensor_row - the tensor row to put the parsed data in | |||
| // @param columns_blob - the blob data received from the reader | |||
| // @param columns_json - the data for fields received from the reader | |||
| template <typename T> | |||
| Status LoadFeature(std::shared_ptr<Tensor> *tensor, int32_t i_col, const std::vector<uint8_t> &columns_blob, | |||
| const mindrecord::json &columns_json) const; | |||
| Status SwitchLoadFeature(const DataType &type, std::shared_ptr<Tensor> *tensor, int32_t i_col, | |||
| const std::vector<uint8_t> &columns_blob, const mindrecord::json &columns_json) const; | |||
| static Status LoadBlob(TensorShape *new_shape, const unsigned char **data, const std::vector<uint8_t> &columns_blob, | |||
| const int32_t pos, const ColDescriptor &column); | |||
| // Get shape and data (scalar or array) for tensor to be created (for floats and doubles) | |||
| // @param new_shape - the shape of tensor to be created. | |||
| // @param array_data - the array where data should be put in | |||
| // @param column_name - name of current column to be processed | |||
| // @param columns_json - the data for fields received from the reader | |||
| // @param column - description of current column from schema | |||
| // @param use_double - boolean to choose between float32 and float64 | |||
| template <typename T> | |||
| static Status LoadFloat(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name, | |||
| const mindrecord::json &columns_json, const ColDescriptor &column, bool use_double); | |||
| // Get shape and data (scalar or array) for tensor to be created (for integers) | |||
| // @param new_shape - the shape of tensor to be created. | |||
| // @param array_data - the array where data should be put in | |||
| // @param column_name - name of current column to be processed | |||
| // @param columns_json - the data for fields received from the reader | |||
| // @param column - description of current column from schema | |||
| template <typename T> | |||
| static Status LoadInt(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name, | |||
| const mindrecord::json &columns_json, const ColDescriptor &column); | |||
| static Status LoadByte(TensorShape *new_shape, std::string *string_data, const std::string &column_name, | |||
| const mindrecord::json &columns_json); | |||
| // Get a single float value from the given json | |||
| // @param value - the float to put the value in | |||
| // @param arrayData - the given json containing the float | |||
| // @param use_double - boolean to choose between float32 and float64 | |||
| template <typename T> | |||
| static Status GetFloat(T *value, const mindrecord::json &data, bool use_double); | |||
| // Get a single integer value from the given json | |||
| // @param value - the integer to put the value in | |||
| // @param arrayData - the given json containing the integer | |||
| template <typename T> | |||
| static Status GetInt(T *value, const mindrecord::json &data); | |||
| Status LoadTensorRow(TensorRow *tensor_row, const std::vector<uint8_t> &columns_blob, | |||
| const mindrecord::json &columns_json); | |||
| Status FetchBlockBuffer(const int32_t &buffer_id); | |||
| @@ -91,8 +91,8 @@ void BindShardReader(const py::module *m) { | |||
| .def("launch", &ShardReader::Launch) | |||
| .def("get_header", &ShardReader::GetShardHeader) | |||
| .def("get_blob_fields", &ShardReader::GetBlobFields) | |||
| .def("get_next", | |||
| (std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>(ShardReader::*)()) & ShardReader::GetNextPy) | |||
| .def("get_next", (std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>(ShardReader::*)()) & | |||
| ShardReader::GetNextPy) | |||
| .def("finish", &ShardReader::Finish) | |||
| .def("close", &ShardReader::Close); | |||
| } | |||
| @@ -65,6 +65,9 @@ const int kUnsignedInt4 = 4; | |||
| enum LabelCategory { kSchemaLabel, kStatisticsLabel, kIndexLabel }; | |||
| const char kVersion[] = "3.0"; | |||
| const std::vector<std::string> kSupportedVersion = {"2.0", kVersion}; | |||
| enum ShardType { | |||
| kNLP = 0, | |||
| kCV = 1, | |||
| @@ -0,0 +1,163 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDRECORD_INCLUDE_SHARD_COLUMN_H_ | |||
| #define MINDRECORD_INCLUDE_SHARD_COLUMN_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "mindrecord/include/shard_header.h" | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| const uint64_t kUnsignedOne = 1; | |||
| const uint64_t kBitsOfByte = 8; | |||
| const uint64_t kDataTypeBits = 2; | |||
| const uint64_t kNumDataOfByte = 4; | |||
| const uint64_t kBytesOfColumnLen = 4; | |||
| const uint64_t kDataTypeBitMask = 3; | |||
| const uint64_t kDataTypes = 6; | |||
| enum IntegerType { kInt8Type = 0, kInt16Type, kInt32Type, kInt64Type }; | |||
| enum ColumnCategory { ColumnInRaw, ColumnInBlob, ColumnNotFound }; | |||
| enum ColumnDataType { | |||
| ColumnBytes = 0, | |||
| ColumnString = 1, | |||
| ColumnInt32 = 2, | |||
| ColumnInt64 = 3, | |||
| ColumnFloat32 = 4, | |||
| ColumnFloat64 = 5, | |||
| ColumnNoDataType = 6 | |||
| }; | |||
| // the tables below are indexed by ColumnDataType: {"bytes", "string", "int32", "int64", "float32", "float64"} | |||
| const uint32_t ColumnDataTypeSize[kDataTypes] = {1, 1, 4, 8, 4, 8}; | |||
| const std::vector<std::string> ColumnDataTypeNameNormalized = {"uint8", "uint8", "int32", | |||
| "int64", "float32", "float64"}; | |||
| const std::unordered_map<std::string, ColumnDataType> ColumnDataTypeMap = { | |||
| {"bytes", ColumnBytes}, {"string", ColumnString}, {"int32", ColumnInt32}, | |||
| {"int64", ColumnInt64}, {"float32", ColumnFloat32}, {"float64", ColumnFloat64}}; | |||
| class ShardColumn { | |||
| public: | |||
| explicit ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer = true); | |||
| ~ShardColumn() = default; | |||
| /// \brief get column value by column name | |||
| MSRStatus GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||
| const json &columns_json, const unsigned char **data, | |||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes, | |||
| ColumnDataType *column_data_type, uint64_t *column_data_type_size, | |||
| std::vector<int64_t> *column_shape); | |||
| /// \brief compress blob | |||
| std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob); | |||
| /// \brief check whether the blob is compressed | |||
| bool CheckCompressBlob() const { return has_compress_blob_; } | |||
| uint64_t GetNumBlobColumn() const { return num_blob_column_; } | |||
| std::vector<std::string> GetColumnName() { return column_name_; } | |||
| std::vector<ColumnDataType> GeColumnDataType() { return column_data_type_; } | |||
| std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; } | |||
| /// \brief get column value from blob | |||
| MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||
| const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr, | |||
| uint64_t *n_bytes); | |||
| private: | |||
| /// \brief get column value from json | |||
| MSRStatus GetColumnFromJson(const std::string &column_name, const json &columns_json, | |||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes); | |||
| /// \brief get float value from json | |||
| template <typename T> | |||
| MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double); | |||
| /// \brief get integer value from json | |||
| template <typename T> | |||
| MSRStatus GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value); | |||
| /// \brief get column offset address and size from blob | |||
| MSRStatus GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob, | |||
| uint64_t *num_bytes, uint64_t *shift_idx); | |||
| /// \brief check if column name is available | |||
| ColumnCategory CheckColumnName(const std::string &column_name); | |||
| /// \brief compress integer column | |||
| static vector<uint8_t> CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type); | |||
| /// \brief uncompress integer array column | |||
| template <typename T> | |||
| static MSRStatus UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *data_ptr, | |||
| const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx); | |||
| /// \brief convert big-endian bytes to unsigned int | |||
| /// \param bytes_array bytes array | |||
| /// \param pos shift address in bytes array | |||
| /// \param i_type integer type | |||
| /// \return unsigned int | |||
| static uint64_t BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos, | |||
| const IntegerType &i_type); | |||
| /// \brief convert unsigned int to big-endian bytes | |||
| /// \param value integer value | |||
| /// \param i_type integer type | |||
| /// \return bytes | |||
| static std::vector<uint8_t> UIntToBytesBig(uint64_t value, const IntegerType &i_type); | |||
| /// \brief convert unsigned int to little-endian bytes | |||
| /// \param value integer value | |||
| /// \param i_type integer type | |||
| /// \return bytes | |||
| static std::vector<uint8_t> UIntToBytesLittle(uint64_t value, const IntegerType &i_type); | |||
| /// \brief convert little-endian bytes to the smallest integer type that can hold the value | |||
| /// \param bytes_array bytes array | |||
| /// \param pos shift address in bytes array | |||
| /// \param src_i_type source integer type | |||
| /// \param dst_i_type (output) destination integer type | |||
| /// \return integer | |||
| static int64_t BytesLittleToMinIntType(const std::vector<uint8_t> &bytes_array, const uint64_t &pos, | |||
| const IntegerType &src_i_type, IntegerType *dst_i_type = nullptr); | |||
| private: | |||
| std::vector<std::string> column_name_; // column name list | |||
| std::vector<ColumnDataType> column_data_type_; // column data type list | |||
| std::vector<std::vector<int64_t>> column_shape_; // column shape list | |||
| std::unordered_map<string, uint64_t> column_name_id_; // column name id map | |||
| std::vector<std::string> blob_column_; // blob column list | |||
| std::unordered_map<std::string, uint64_t> blob_column_id_; // blob column name id map | |||
| bool has_compress_blob_; // whether blob integer columns are compressed | |||
| uint64_t num_blob_column_; // number of blob columns | |||
| }; | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| #endif // MINDRECORD_INCLUDE_SHARD_COLUMN_H_ | |||
| @@ -118,8 +118,6 @@ class ShardHeader { | |||
| void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } | |||
| const string GetVersion() { return version_; } | |||
| std::vector<std::string> SerializeHeader(); | |||
| MSRStatus PagesToFile(const std::string dump_file_name); | |||
| @@ -175,7 +173,6 @@ class ShardHeader { | |||
| uint32_t shard_count_; | |||
| uint64_t header_size_; | |||
| uint64_t page_size_; | |||
| string version_ = "2.0"; | |||
| std::shared_ptr<Index> index_; | |||
| std::vector<std::string> shard_addresses_; | |||
| @@ -43,6 +43,7 @@ | |||
| #include <vector> | |||
| #include "mindrecord/include/common/shard_utils.h" | |||
| #include "mindrecord/include/shard_category.h" | |||
| #include "mindrecord/include/shard_column.h" | |||
| #include "mindrecord/include/shard_error.h" | |||
| #include "mindrecord/include/shard_index_generator.h" | |||
| #include "mindrecord/include/shard_operator.h" | |||
| @@ -111,6 +112,10 @@ class ShardReader { | |||
| /// \return the metadata | |||
| std::shared_ptr<ShardHeader> GetShardHeader() const; | |||
| /// \brief get the shard column object that holds per-column metadata | |||
| /// \return the shard column object | |||
| std::shared_ptr<ShardColumn> get_shard_column() const; | |||
| /// \brief get the number of shards | |||
| /// \return # of shards | |||
| int GetShardCount() const; | |||
| @@ -185,7 +190,7 @@ class ShardReader { | |||
| /// \brief return a batch, given that one is ready, python API | |||
| /// \return a batch of images and image data | |||
| std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>> GetNextPy(); | |||
| std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> GetNextPy(); | |||
| /// \brief get blob field list | |||
| /// \return blob field list | |||
| @@ -295,16 +300,18 @@ class ShardReader { | |||
| /// \brief get number of classes | |||
| int64_t GetNumClasses(const std::string &category_field); | |||
| /// \brief get meta of header | |||
| std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path, json &meta_data); | |||
| /// \brief get exactly blob fields data by indices | |||
| std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes, | |||
| std::vector<uint32_t> &ordered_selected_columns_index); | |||
| /// \brief extract uncompressed data based on column list | |||
| std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> UnCompressBlob(const std::vector<uint8_t> &raw_blob_data); | |||
| protected: | |||
| uint64_t header_size_; // header size | |||
| uint64_t page_size_; // page size | |||
| int shard_count_; // number of shards | |||
| std::shared_ptr<ShardHeader> shard_header_; // shard header | |||
| std::shared_ptr<ShardColumn> shard_column_; // shard column | |||
| std::vector<sqlite3 *> database_paths_; // sqlite handle list | |||
| std::vector<string> file_paths_; // file paths | |||
| @@ -36,6 +36,7 @@ | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "mindrecord/include/common/shard_utils.h" | |||
| #include "mindrecord/include/shard_column.h" | |||
| #include "mindrecord/include/shard_error.h" | |||
| #include "mindrecord/include/shard_header.h" | |||
| #include "mindrecord/include/shard_index.h" | |||
| @@ -242,7 +243,8 @@ class ShardWriter { | |||
| std::vector<std::string> file_paths_; // file paths | |||
| std::vector<std::shared_ptr<std::fstream>> file_streams_; // file handles | |||
| std::shared_ptr<ShardHeader> shard_header_; // shard headers | |||
| std::shared_ptr<ShardHeader> shard_header_; // shard header | |||
| std::shared_ptr<ShardColumn> shard_column_; // shard columns | |||
| std::map<uint64_t, std::map<int, std::string>> err_mg_; // used for storing error raw_data info | |||
| @@ -133,6 +133,12 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa | |||
| shard_header_ = std::make_shared<ShardHeader>(sh); | |||
| header_size_ = shard_header_->GetHeaderSize(); | |||
| page_size_ = shard_header_->GetPageSize(); | |||
| // version < 3.0 | |||
| if (first_meta_data["version"] < kVersion) { | |||
| shard_column_ = std::make_shared<ShardColumn>(shard_header_, false); | |||
| } else { | |||
| shard_column_ = std::make_shared<ShardColumn>(shard_header_, true); | |||
| } | |||
| num_rows_ = 0; | |||
| auto row_group_summary = ReadRowGroupSummary(); | |||
| for (const auto &rg : row_group_summary) { | |||
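The version gate added in the hunk above relies on kVersion ("3.0", defined in shard_utils.h earlier in this change): files written by pre-3.0 writers get a ShardColumn with integer compression disabled, matching the compress_integer constructor parameter. A minimal sketch of that decision, assuming the json comparison behaves as a plain string comparison and with the file version hard-coded for illustration:

#include <iostream>
#include <string>

int main() {
  const std::string kVersion = "3.0";    // from shard_utils.h
  std::string file_version = "2.0";      // would normally come from the shard file header
  bool compress_integer = !(file_version < kVersion);  // pre-3.0 files stay uncompressed
  std::cout << "compress_integer = " << std::boolalpha << compress_integer << "\n";
  return 0;
}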
| @@ -226,6 +232,8 @@ void ShardReader::Close() { | |||
| std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; } | |||
| std::shared_ptr<ShardColumn> ShardReader::get_shard_column() const { return shard_column_; } | |||
| int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); } | |||
| int ShardReader::GetNumRows() const { return num_rows_; } | |||
| @@ -1059,36 +1067,6 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u | |||
| return SUCCESS; | |||
| } | |||
| std::vector<uint8_t> ShardReader::ExtractBlobFieldBySelectColumns( | |||
| std::vector<uint8_t> &blob_fields_bytes, std::vector<uint32_t> &ordered_selected_columns_index) { | |||
| std::vector<uint8_t> exactly_blob_fields_bytes; | |||
| auto uint64_from_bytes = [&](int64_t pos) { | |||
| uint64_t result = 0; | |||
| for (uint64_t n = 0; n < kInt64Len; n++) { | |||
| result = (result << 8) + blob_fields_bytes[pos + n]; | |||
| } | |||
| return result; | |||
| }; | |||
| // get the exactly blob fields | |||
| uint32_t current_index = 0; | |||
| uint64_t current_offset = 0; | |||
| uint64_t data_len = uint64_from_bytes(current_offset); | |||
| while (current_offset < blob_fields_bytes.size()) { | |||
| if (std::any_of(ordered_selected_columns_index.begin(), ordered_selected_columns_index.end(), | |||
| [¤t_index](uint32_t &index) { return index == current_index; })) { | |||
| exactly_blob_fields_bytes.insert(exactly_blob_fields_bytes.end(), blob_fields_bytes.begin() + current_offset, | |||
| blob_fields_bytes.begin() + current_offset + kInt64Len + data_len); | |||
| } | |||
| current_index++; | |||
| current_offset += kInt64Len + data_len; | |||
| data_len = uint64_from_bytes(current_offset); | |||
| } | |||
| return exactly_blob_fields_bytes; | |||
| } | |||
| TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) { | |||
| // All tasks are done | |||
| if (task_id >= static_cast<int>(tasks_.Size())) { | |||
| @@ -1126,40 +1104,10 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ | |||
| return std::make_pair(FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>()); | |||
| } | |||
| // extract the exactly blob bytes by selected columns | |||
| std::vector<uint8_t> images_with_exact_columns; | |||
| if (selected_columns_.size() == 0) { | |||
| images_with_exact_columns = images; | |||
| } else { | |||
| auto blob_fields = GetBlobFields(); | |||
| std::vector<uint32_t> ordered_selected_columns_index; | |||
| uint32_t index = 0; | |||
| for (auto &blob_field : blob_fields.second) { | |||
| for (auto &field : selected_columns_) { | |||
| if (field.compare(blob_field) == 0) { | |||
| ordered_selected_columns_index.push_back(index); | |||
| break; | |||
| } | |||
| } | |||
| index++; | |||
| } | |||
| if (ordered_selected_columns_index.size() != 0) { | |||
| // extract the images | |||
| if (blob_fields.second.size() == 1) { | |||
| if (ordered_selected_columns_index.size() == 1) { | |||
| images_with_exact_columns = images; | |||
| } | |||
| } else { | |||
| images_with_exact_columns = ExtractBlobFieldBySelectColumns(images, ordered_selected_columns_index); | |||
| } | |||
| } | |||
| } | |||
| // Deliver batch data to output map | |||
| std::vector<std::tuple<std::vector<uint8_t>, json>> batch; | |||
| batch.emplace_back(std::move(images_with_exact_columns), std::move(std::get<2>(task))); | |||
| batch.emplace_back(std::move(images), std::move(std::get<2>(task))); | |||
| return std::make_pair(SUCCESS, std::move(batch)); | |||
| } | |||
| @@ -1369,16 +1317,41 @@ std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNextById(con | |||
| return std::move(ret.second); | |||
| } | |||
| std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>> ShardReader::GetNextPy() { | |||
| std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardReader::UnCompressBlob( | |||
| const std::vector<uint8_t> &raw_blob_data) { | |||
| auto loaded_columns = selected_columns_.size() == 0 ? shard_column_->GetColumnName() : selected_columns_; | |||
| auto blob_fields = GetBlobFields().second; | |||
| std::vector<std::vector<uint8_t>> blob_data; | |||
| for (uint32_t i_col = 0; i_col < loaded_columns.size(); ++i_col) { | |||
| if (std::find(blob_fields.begin(), blob_fields.end(), loaded_columns[i_col]) == blob_fields.end()) continue; | |||
| const unsigned char *data = nullptr; | |||
| std::unique_ptr<unsigned char[]> data_ptr; | |||
| uint64_t n_bytes = 0; | |||
| auto ret = shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Error when get data from blob, column name is " << loaded_columns[i_col] << "."; | |||
| return {FAILED, std::vector<std::vector<uint8_t>>(blob_fields.size(), std::vector<uint8_t>())}; | |||
| } | |||
| if (data == nullptr) { | |||
| data = reinterpret_cast<const unsigned char *>(data_ptr.get()); | |||
| } | |||
| std::vector<uint8_t> column(data, data + (n_bytes / sizeof(unsigned char))); | |||
| blob_data.push_back(column); | |||
| } | |||
| return {SUCCESS, blob_data}; | |||
| } | |||
| std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> ShardReader::GetNextPy() { | |||
| auto res = GetNext(); | |||
| vector<std::tuple<std::vector<uint8_t>, pybind11::object>> jsonData; | |||
| std::transform(res.begin(), res.end(), std::back_inserter(jsonData), | |||
| [](const std::tuple<std::vector<uint8_t>, json> &item) { | |||
| vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> data; | |||
| std::transform(res.begin(), res.end(), std::back_inserter(data), | |||
| [this](const std::tuple<std::vector<uint8_t>, json> &item) { | |||
| auto &j = std::get<1>(item); | |||
| pybind11::object obj = nlohmann::detail::FromJsonImpl(j); | |||
| return std::make_tuple(std::get<0>(item), std::move(obj)); | |||
| auto ret = UnCompressBlob(std::get<0>(item)); | |||
| return std::make_tuple(ret.second, std::move(obj)); | |||
| }); | |||
| return jsonData; | |||
| return data; | |||
| } | |||
| void ShardReader::Reset() { | |||
| @@ -206,6 +206,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { | |||
| MS_LOG(ERROR) << "Open file failed"; | |||
| return FAILED; | |||
| } | |||
| shard_column_ = std::make_shared<ShardColumn>(shard_header_); | |||
| return SUCCESS; | |||
| } | |||
| @@ -271,6 +272,7 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data) | |||
| shard_header_ = header_data; | |||
| shard_header_->SetHeaderSize(header_size_); | |||
| shard_header_->SetPageSize(page_size_); | |||
| shard_column_ = std::make_shared<ShardColumn>(shard_header_); | |||
| return SUCCESS; | |||
| } | |||
| @@ -608,6 +610,14 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json> | |||
| MS_LOG(ERROR) << "IO error / there is no free disk to be used"; | |||
| return FAILED; | |||
| } | |||
| // compress blob | |||
| if (shard_column_->CheckCompressBlob()) { | |||
| for (auto &blob : blob_data) { | |||
| blob = shard_column_->CompressBlob(blob); | |||
| } | |||
| } | |||
| // Add 4-byte dummy blob data if there are no blob fields | |||
| if (blob_data.size() == 0 && raw_data.size() > 0) { | |||
| blob_data = std::vector<std::vector<uint8_t>>(raw_data[0].size(), std::vector<uint8_t>(kUnsignedInt4, 0)); | |||
| @@ -0,0 +1,473 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindrecord/include/shard_column.h" | |||
| #include "common/utils.h" | |||
| #include "mindrecord/include/common/shard_utils.h" | |||
| #include "mindrecord/include/shard_error.h" | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer) { | |||
| auto first_schema = shard_header->GetSchemas()[0]; | |||
| auto schema = first_schema->GetSchema()["schema"]; | |||
| bool has_integer_array = false; | |||
| for (json::iterator it = schema.begin(); it != schema.end(); ++it) { | |||
| const std::string &column_name = it.key(); | |||
| column_name_.push_back(column_name); | |||
| json it_value = it.value(); | |||
| std::string str_type = it_value["type"]; | |||
| column_data_type_.push_back(ColumnDataTypeMap.at(str_type)); | |||
| if (it_value.find("shape") != it_value.end()) { | |||
| std::vector<int64_t> vec(it_value["shape"].size()); | |||
| std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin()); | |||
| column_shape_.push_back(vec); | |||
| if (str_type == "int32" || str_type == "int64") { | |||
| has_integer_array = true; | |||
| } | |||
| } else { | |||
| std::vector<int64_t> vec = {}; | |||
| column_shape_.push_back(vec); | |||
| } | |||
| } | |||
| for (uint64_t i = 0; i < column_name_.size(); i++) { | |||
| column_name_id_[column_name_[i]] = i; | |||
| } | |||
| auto blob_fields = first_schema->GetBlobFields(); | |||
| for (const auto &field : blob_fields) { | |||
| blob_column_.push_back(field); | |||
| } | |||
| for (uint64_t i = 0; i < blob_column_.size(); i++) { | |||
| blob_column_id_[blob_column_[i]] = i; | |||
| } | |||
| has_compress_blob_ = (compress_integer && has_integer_array); | |||
| num_blob_column_ = blob_column_.size(); | |||
| } | |||
| MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||
| const json &columns_json, const unsigned char **data, | |||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes, | |||
| ColumnDataType *column_data_type, uint64_t *column_data_type_size, | |||
| std::vector<int64_t> *column_shape) { | |||
| // Skip if column not found | |||
| auto column_category = CheckColumnName(column_name); | |||
| if (column_category == ColumnNotFound) { | |||
| return FAILED; | |||
| } | |||
| // Get data type and size | |||
| auto column_id = column_name_id_[column_name]; | |||
| *column_data_type = column_data_type_[column_id]; | |||
| *column_data_type_size = ColumnDataTypeSize[*column_data_type]; | |||
| *column_shape = column_shape_[column_id]; | |||
| // Retrieve value from json | |||
| if (column_category == ColumnInRaw) { | |||
| if (GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes) == FAILED) { | |||
| MS_LOG(ERROR) << "Error when get data from json, column name is " << column_name << "."; | |||
| return FAILED; | |||
| } | |||
| *data = reinterpret_cast<const unsigned char *>(data_ptr->get()); | |||
| return SUCCESS; | |||
| } | |||
| // Retrieve value from blob | |||
| if (GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes) == FAILED) { | |||
| MS_LOG(ERROR) << "Error when get data from blob, column name is " << column_name << "."; | |||
| return FAILED; | |||
| } | |||
| if (*data == nullptr) { | |||
| *data = reinterpret_cast<const unsigned char *>(data_ptr->get()); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json, | |||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes) { | |||
| auto column_id = column_name_id_[column_name]; | |||
| auto column_data_type = column_data_type_[column_id]; | |||
| // Initialize num bytes | |||
| *n_bytes = ColumnDataTypeSize[column_data_type]; | |||
| auto json_column_value = columns_json[column_name]; | |||
| switch (column_data_type) { | |||
| case ColumnFloat32: { | |||
| return GetFloat<float>(data_ptr, json_column_value, false); | |||
| } | |||
| case ColumnFloat64: { | |||
| return GetFloat<double>(data_ptr, json_column_value, true); | |||
| } | |||
| case ColumnInt32: { | |||
| return GetInt<int32_t>(data_ptr, json_column_value); | |||
| } | |||
| case ColumnInt64: { | |||
| return GetInt<int64_t>(data_ptr, json_column_value); | |||
| } | |||
| default: { | |||
| // Convert string to c_str | |||
| std::string tmp_string = json_column_value; | |||
| *n_bytes = tmp_string.size(); | |||
| auto data = reinterpret_cast<const unsigned char *>(common::SafeCStr(tmp_string)); | |||
| *data_ptr = std::make_unique<unsigned char[]>(*n_bytes); | |||
| for (uint32_t i = 0; i < *n_bytes; i++) { | |||
| (*data_ptr)[i] = *(data + i); | |||
| } | |||
| break; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| template <typename T> | |||
| MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, | |||
| bool use_double) { | |||
| std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1); | |||
| if (!json_column_value.is_string() && !json_column_value.is_number()) { | |||
| MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ")."; | |||
| return FAILED; | |||
| } | |||
| if (json_column_value.is_number()) { | |||
| array_data[0] = json_column_value; | |||
| } else { | |||
| // Convert string to float | |||
| try { | |||
| if (use_double) { | |||
| array_data[0] = json_column_value.get<double>(); | |||
| } else { | |||
| array_data[0] = json_column_value.get<float>(); | |||
| } | |||
| } catch (json::exception &e) { | |||
| MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ")."; | |||
| return FAILED; | |||
| } | |||
| } | |||
| auto data = reinterpret_cast<const unsigned char *>(array_data.get()); | |||
| *data_ptr = std::make_unique<unsigned char[]>(sizeof(T)); | |||
| for (uint32_t i = 0; i < sizeof(T); i++) { | |||
| (*data_ptr)[i] = *(data + i); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| template <typename T> | |||
| MSRStatus ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value) { | |||
| std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1); | |||
| int64_t temp_value; | |||
| bool less_than_zero = false; | |||
| if (json_column_value.is_number_integer()) { | |||
| const json json_zero = 0; | |||
| if (json_column_value < json_zero) less_than_zero = true; | |||
| temp_value = json_column_value; | |||
| } else if (json_column_value.is_string()) { | |||
| std::string string_value = json_column_value; | |||
| if (!string_value.empty() && string_value[0] == '-') { | |||
| try { | |||
| temp_value = std::stoll(string_value); | |||
| less_than_zero = true; | |||
| } catch (std::invalid_argument &e) { | |||
| MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; | |||
| return FAILED; | |||
| } catch (std::out_of_range &e) { | |||
| MS_LOG(ERROR) << "Conversion to int failed, out of range."; | |||
| return FAILED; | |||
| } | |||
| } else { | |||
| try { | |||
| temp_value = static_cast<int64_t>(std::stoull(string_value)); | |||
| } catch (std::invalid_argument &e) { | |||
| MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; | |||
| return FAILED; | |||
| } catch (std::out_of_range &e) { | |||
| MS_LOG(ERROR) << "Conversion to int failed, out of range."; | |||
| return FAILED; | |||
| } | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "Conversion to int failed."; | |||
| return FAILED; | |||
| } | |||
| if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) || | |||
| (!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) { | |||
| MS_LOG(ERROR) << "Conversion to int failed. Out of range"; | |||
| return FAILED; | |||
| } | |||
| array_data[0] = static_cast<T>(temp_value); | |||
| auto data = reinterpret_cast<const unsigned char *>(array_data.get()); | |||
| *data_ptr = std::make_unique<unsigned char[]>(sizeof(T)); | |||
| for (uint32_t i = 0; i < sizeof(T); i++) { | |||
| (*data_ptr)[i] = *(data + i); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| MSRStatus ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||
| const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr, | |||
| uint64_t *n_bytes) { | |||
| uint64_t offset_address = 0; | |||
| auto column_id = column_name_id_[column_name]; | |||
| if (GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address) == FAILED) { | |||
| return FAILED; | |||
| } | |||
| auto column_data_type = column_data_type_[column_id]; | |||
| if (has_compress_blob_ && column_data_type == ColumnInt32) { | |||
| if (UncompressInt<int32_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { | |||
| return FAILED; | |||
| } | |||
| } else if (has_compress_blob_ && column_data_type == ColumnInt64) { | |||
| if (UncompressInt<int64_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { | |||
| return FAILED; | |||
| } | |||
| } else { | |||
| *data = reinterpret_cast<const unsigned char *>(&(columns_blob[offset_address])); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) { | |||
| auto it_column = column_name_id_.find(column_name); | |||
| if (it_column == column_name_id_.end()) { | |||
| return ColumnNotFound; | |||
| } | |||
| auto it_blob = blob_column_id_.find(column_name); | |||
| return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob; | |||
| } | |||
| std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob) { | |||
| // Skip if there are no columns to compress | |||
| if (!CheckCompressBlob()) return blob; | |||
| std::vector<uint8_t> dst_blob; | |||
| uint64_t i_src = 0; | |||
| for (int64_t i = 0; i < num_blob_column_; i++) { | |||
| // Get column data type | |||
| auto src_data_type = column_data_type_[column_name_id_[blob_column_[i]]]; | |||
| auto int_type = src_data_type == ColumnInt32 ? kInt32Type : kInt64Type; | |||
| // Compress and return directly if the blob has only one column | |||
| if (num_blob_column_ == 1) { | |||
| return CompressInt(blob, int_type); | |||
| } | |||
| // Just copy and continue if the column data type is not int32/int64 | |||
| uint64_t num_bytes = BytesBigToUInt64(blob, i_src, kInt64Type); | |||
| if (src_data_type != ColumnInt32 && src_data_type != ColumnInt64) { | |||
| dst_blob.insert(dst_blob.end(), blob.begin() + i_src, blob.begin() + i_src + kInt64Len + num_bytes); | |||
| i_src += kInt64Len + num_bytes; | |||
| continue; | |||
| } | |||
| // Get column slice in source blob | |||
| std::vector<uint8_t> blob_slice(blob.begin() + i_src + kInt64Len, blob.begin() + i_src + kInt64Len + num_bytes); | |||
| // Compress column | |||
| auto dst_blob_slice = CompressInt(blob_slice, int_type); | |||
| // Get new column size | |||
| auto new_blob_size = UIntToBytesBig(dst_blob_slice.size(), kInt64Type); | |||
| // Append new column size | |||
| dst_blob.insert(dst_blob.end(), new_blob_size.begin(), new_blob_size.end()); | |||
| // Append new column data | |||
| dst_blob.insert(dst_blob.end(), dst_blob_slice.begin(), dst_blob_slice.end()); | |||
| i_src += kInt64Len + num_bytes; | |||
| } | |||
| MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << "."; | |||
| return dst_blob; | |||
| } | |||
| vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type) { | |||
| uint64_t i_size = kUnsignedOne << int_type; | |||
| // Get number of elements | |||
| uint64_t src_n_int = src_bytes.size() / i_size; | |||
| // Calculate bitmap size (bytes) | |||
| uint64_t bitmap_size = (src_n_int + kNumDataOfByte - 1) / kNumDataOfByte; | |||
| // Initialize destination blob with more space than needed; it will be resized | |||
| vector<uint8_t> dst_bytes(kBytesOfColumnLen + bitmap_size + src_bytes.size(), 0); | |||
| // Write number of elements to destination blob | |||
| vector<uint8_t> size_by_bytes = UIntToBytesBig(src_n_int, kInt32Type); | |||
| for (uint64_t n = 0; n < kBytesOfColumnLen; n++) { | |||
| dst_bytes[n] = size_by_bytes[n]; | |||
| } | |||
| // Write compressed int | |||
| uint64_t i_dst = kBytesOfColumnLen + bitmap_size; | |||
| for (uint64_t i = 0; i < src_n_int; i++) { | |||
| // Initialize destination data type | |||
| IntegerType dst_int_type = kInt8Type; | |||
| // Shift to next int position | |||
| uint64_t pos = i * (kUnsignedOne << int_type); | |||
| // Narrow down this int | |||
| int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, &dst_int_type); | |||
| // Write this int to destination blob | |||
| uint64_t u_n = *reinterpret_cast<uint64_t *>(&i_n); | |||
| auto temp_bytes = UIntToBytesLittle(u_n, dst_int_type); | |||
| for (uint64_t j = 0; j < (kUnsignedOne << dst_int_type); j++) { | |||
| dst_bytes[i_dst++] = temp_bytes[j]; | |||
| } | |||
| // Update data type in bitmap | |||
| dst_bytes[i / kNumDataOfByte + kBytesOfColumnLen] |= | |||
| (dst_int_type << (kDataTypeBits * (kNumDataOfByte - kUnsignedOne - (i % kNumDataOfByte)))); | |||
| } | |||
| // Resize destination blob | |||
| dst_bytes.resize(i_dst); | |||
| MS_LOG(DEBUG) << "Compress blob field from " << src_bytes.size() << " to " << dst_bytes.size() << "."; | |||
| return dst_bytes; | |||
| } | |||
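A minimal sketch of the per-column payload that `CompressInt` emits, assuming `kBytesOfColumnLen` is 4, `kNumDataOfByte` is 4, and the 2-bit type codes 0..3 select 1/2/4/8-byte little-endian integers (entries packed MSB-first within each bitmap byte). `compress_int_sketch` is a hypothetical name for illustration only.

```python
def compress_int_sketch(values):
    """Layout: [4-byte BE count][bitmap, 2 bits per value][values, narrowed, little-endian]."""
    out = bytearray(len(values).to_bytes(4, 'big'))
    bitmap = bytearray((len(values) + 3) // 4)
    body = bytearray()
    for i, v in enumerate(values):
        # pick the smallest signed width (1, 2, 4 or 8 bytes) that can hold v
        for code, width in enumerate((1, 2, 4, 8)):
            if -(1 << (8 * width - 1)) <= v < (1 << (8 * width - 1)):
                break
        bitmap[i // 4] |= code << (2 * (3 - i % 4))  # MSB-first 2-bit entries
        body += v.to_bytes(width, 'little', signed=True)
    return bytes(out + bitmap + body)
```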
| MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob, | |||
| uint64_t *num_bytes, uint64_t *shift_idx) { | |||
| if (num_blob_column_ == 1) { | |||
| *num_bytes = columns_blob.size(); | |||
| *shift_idx = 0; | |||
| return SUCCESS; | |||
| } | |||
| auto blob_id = blob_column_id_[column_name_[column_id]]; | |||
| for (int32_t i = 0; i < blob_id; i++) { | |||
| *shift_idx += kInt64Len + BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type); | |||
| } | |||
| *num_bytes = BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type); | |||
| (*shift_idx) += kInt64Len; | |||
| return SUCCESS; | |||
| } | |||
| template <typename T> | |||
| MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *data_ptr, | |||
| const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, | |||
| uint64_t shift_idx) { | |||
| auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type); | |||
| *num_bytes = sizeof(T) * num_elements; | |||
| // Parse integer array | |||
| uint64_t i_source = shift_idx + kBytesOfColumnLen + (num_elements + kNumDataOfByte - 1) / kNumDataOfByte; | |||
| auto array_data = std::make_unique<T[]>(num_elements); | |||
| for (uint64_t i = 0; i < num_elements; i++) { | |||
| uint8_t iBitMap = columns_blob[shift_idx + kBytesOfColumnLen + i / kNumDataOfByte]; | |||
| uint64_t i_type = (iBitMap >> ((kNumDataOfByte - 1 - (i % kNumDataOfByte)) * kDataTypeBits)) & kDataTypeBitMask; | |||
| auto mr_int_type = static_cast<IntegerType>(i_type); | |||
| int64_t i64 = BytesLittleToMinIntType(columns_blob, i_source, mr_int_type); | |||
| i_source += (kUnsignedOne << i_type); | |||
| array_data[i] = static_cast<T>(i64); | |||
| } | |||
| auto data = reinterpret_cast<const unsigned char *>(array_data.get()); | |||
| *data_ptr = std::make_unique<unsigned char[]>(*num_bytes); | |||
| memcpy(data_ptr->get(), data, *num_bytes); | |||
| return SUCCESS; | |||
| } | |||
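And the hypothetical inverse, mirroring `UncompressInt`: read the big-endian count, skip the 2-bit bitmap, then read each value at the width its bitmap code selects. It round-trips with the encoder sketch above.

```python
def uncompress_int_sketch(payload: bytes):
    count = int.from_bytes(payload[:4], 'big')
    pos = 4 + (count + 3) // 4  # skip count and bitmap
    values = []
    for i in range(count):
        code = (payload[4 + i // 4] >> (2 * (3 - i % 4))) & 0b11
        width = 1 << code  # 1, 2, 4 or 8 bytes
        values.append(int.from_bytes(payload[pos:pos + width], 'little', signed=True))
        pos += width
    return values

assert uncompress_int_sketch(compress_int_sketch([0, 1, -1, 300, -70000])) == [0, 1, -1, 300, -70000]
```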
| uint64_t ShardColumn::BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos, | |||
| const IntegerType &i_type) { | |||
| uint64_t result = 0; | |||
| for (uint64_t i = 0; i < (kUnsignedOne << i_type); i++) { | |||
| result = (result << kBitsOfByte) + bytes_array[pos + i]; | |||
| } | |||
| return result; | |||
| } | |||
| std::vector<uint8_t> ShardColumn::UIntToBytesBig(uint64_t value, const IntegerType &i_type) { | |||
| uint64_t n_bytes = kUnsignedOne << i_type; | |||
| std::vector<uint8_t> result(n_bytes, 0); | |||
| for (uint64_t i = 0; i < n_bytes; i++) { | |||
| result[n_bytes - 1 - i] = value & std::numeric_limits<uint8_t>::max(); | |||
| value >>= kBitsOfByte; | |||
| } | |||
| return result; | |||
| } | |||
| std::vector<uint8_t> ShardColumn::UIntToBytesLittle(uint64_t value, const IntegerType &i_type) { | |||
| uint64_t n_bytes = kUnsignedOne << i_type; | |||
| std::vector<uint8_t> result(n_bytes, 0); | |||
| for (uint64_t i = 0; i < n_bytes; i++) { | |||
| result[i] = value & std::numeric_limits<uint8_t>::max(); | |||
| value >>= kBitsOfByte; | |||
| } | |||
| return result; | |||
| } | |||
| int64_t ShardColumn::BytesLittleToMinIntType(const std::vector<uint8_t> &bytes_array, const uint64_t &pos, | |||
| const IntegerType &src_i_type, IntegerType *dst_i_type) { | |||
| uint64_t u_temp = 0; | |||
| for (uint64_t i = 0; i < (kUnsignedOne << src_i_type); i++) { | |||
| u_temp = (u_temp << kBitsOfByte) + bytes_array[pos + (kUnsignedOne << src_i_type) - kUnsignedOne - i]; | |||
| } | |||
| int64_t i_out; | |||
| switch (src_i_type) { | |||
| case kInt8Type: { | |||
| i_out = (int8_t)(u_temp & std::numeric_limits<uint8_t>::max()); | |||
| break; | |||
| } | |||
| case kInt16Type: { | |||
| i_out = (int16_t)(u_temp & std::numeric_limits<uint16_t>::max()); | |||
| break; | |||
| } | |||
| case kInt32Type: { | |||
| i_out = (int32_t)(u_temp & std::numeric_limits<uint32_t>::max()); | |||
| break; | |||
| } | |||
| case kInt64Type: { | |||
| i_out = (int64_t)(u_temp & std::numeric_limits<uint64_t>::max()); | |||
| break; | |||
| } | |||
| default: { | |||
| i_out = 0; | |||
| } | |||
| } | |||
| if (!dst_i_type) { | |||
| return i_out; | |||
| } | |||
| if (i_out >= static_cast<int64_t>(std::numeric_limits<int8_t>::min()) && | |||
| i_out <= static_cast<int64_t>(std::numeric_limits<int8_t>::max())) { | |||
| *dst_i_type = kInt8Type; | |||
| } else if (i_out >= static_cast<int64_t>(std::numeric_limits<int16_t>::min()) && | |||
| i_out <= static_cast<int64_t>(std::numeric_limits<int16_t>::max())) { | |||
| *dst_i_type = kInt16Type; | |||
| } else if (i_out >= static_cast<int64_t>(std::numeric_limits<int32_t>::min()) && | |||
| i_out <= static_cast<int64_t>(std::numeric_limits<int32_t>::max())) { | |||
| *dst_i_type = kInt32Type; | |||
| } else { | |||
| *dst_i_type = kInt64Type; | |||
| } | |||
| return i_out; | |||
| } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -201,9 +201,9 @@ void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &heade | |||
| json header; | |||
| header = ret.second; | |||
| header["shard_addresses"] = realAddresses; | |||
| if (header["version"] != version_) { | |||
| if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), header["version"]) == kSupportedVersion.end()) { | |||
| MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump() | |||
| << ", lib version is: " << version_; | |||
| << ", lib version is: " << kVersion; | |||
| thread_status = true; | |||
| return; | |||
| } | |||
| @@ -339,7 +339,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() { | |||
| s += "\"shard_addresses\":" + address + ","; | |||
| s += "\"shard_id\":" + std::to_string(shardId) + ","; | |||
| s += "\"statistics\":" + stats + ","; | |||
| s += "\"version\":\"" + version_ + "\""; | |||
| s += "\"version\":\"" + std::string(kVersion) + "\""; | |||
| s += "}"; | |||
| header.emplace_back(s); | |||
| } | |||
| @@ -97,16 +97,13 @@ def populate_data(raw, blob, columns, blob_fields, schema): | |||
| if not blob_fields: | |||
| return raw | |||
| # Get the order preserving sequence of columns in blob | |||
| ordered_columns = [] | |||
| loaded_columns = [] | |||
| if columns: | |||
| for blob_field in blob_fields: | |||
| if blob_field in columns: | |||
| ordered_columns.append(blob_field) | |||
| for column in columns: | |||
| if column in blob_fields: | |||
| loaded_columns.append(column) | |||
| else: | |||
| ordered_columns = blob_fields | |||
| blob_bytes = bytes(blob) | |||
| loaded_columns = blob_fields | |||
| def _render_raw(field, blob_data): | |||
| data_type = schema[field]['type'] | |||
| @@ -119,24 +116,6 @@ def populate_data(raw, blob, columns, blob_fields, schema): | |||
| else: | |||
| raw[field] = blob_data | |||
| if len(blob_fields) == 1: | |||
| if len(ordered_columns) == 1: | |||
| _render_raw(blob_fields[0], blob_bytes) | |||
| return raw | |||
| return raw | |||
| def _int_from_bytes(xbytes: bytes) -> int: | |||
| return int.from_bytes(xbytes, 'big') | |||
| def _blob_at_position(pos): | |||
| start = 0 | |||
| for _ in range(pos): | |||
| n_bytes = _int_from_bytes(blob_bytes[start : start + 8]) | |||
| start += 8 + n_bytes | |||
| n_bytes = _int_from_bytes(blob_bytes[start : start + 8]) | |||
| start += 8 | |||
| return blob_bytes[start : start + n_bytes] | |||
| for i, blob_field in enumerate(ordered_columns): | |||
| _render_raw(blob_field, _blob_at_position(i)) | |||
| for i, blob_field in enumerate(loaded_columns): | |||
| _render_raw(blob_field, bytes(blob[i])) | |||
| return raw | |||
| @@ -35,6 +35,7 @@ CV1_FILE_NAME = "../data/mindrecord/imagenet1.mindrecord" | |||
| CV2_FILE_NAME = "../data/mindrecord/imagenet2.mindrecord" | |||
| CV_DIR_NAME = "../data/mindrecord/testImageNetData" | |||
| NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord" | |||
| OLD_NLP_FILE_NAME = "../data/mindrecord/testOldVersion/aclImdb.mindrecord" | |||
| NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos" | |||
| NLP_FILE_VOCAB = "../data/mindrecord/testAclImdbData/vocab.txt" | |||
| @@ -46,7 +47,8 @@ def add_and_remove_cv_file(): | |||
| for x in range(FILES_NUM)] | |||
| for x in paths: | |||
| os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None | |||
| os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None | |||
| os.remove("{}.db".format(x)) if os.path.exists( | |||
| "{}.db".format(x)) else None | |||
| writer = FileWriter(CV_FILE_NAME, FILES_NUM) | |||
| data = get_data(CV_DIR_NAME) | |||
| cv_schema_json = {"id": {"type": "int32"}, | |||
| @@ -96,13 +98,105 @@ def add_and_remove_nlp_file(): | |||
| os.remove("{}.db".format(x)) | |||
| @pytest.fixture | |||
| def add_and_remove_nlp_compress_file(): | |||
| """add/remove nlp file""" | |||
| paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| for x in paths: | |||
| if os.path.exists("{}".format(x)): | |||
| os.remove("{}".format(x)) | |||
| if os.path.exists("{}.db".format(x)): | |||
| os.remove("{}.db".format(x)) | |||
| writer = FileWriter(NLP_FILE_NAME, FILES_NUM) | |||
| data = [] | |||
| for row_id in range(16): | |||
| data.append({ | |||
| "label": row_id, | |||
| "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, | |||
| 255, 256, -32768, 32767, -32769, 32768, -2147483648, | |||
| 2147483647], dtype=np.int32), [-1]), | |||
| "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255, | |||
| 256, -32768, 32767, -32769, 32768, -2147483648, 2147483647, -2147483649, 2147483649, -922337036854775808, 9223372036854775807]), [1, -1]), | |||
| "array_c": str.encode("nlp data"), | |||
| "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1]) | |||
| }) | |||
| nlp_schema_json = {"label": {"type": "int32"}, | |||
| "array_a": {"type": "int32", | |||
| "shape": [-1]}, | |||
| "array_b": {"type": "int64", | |||
| "shape": [1, -1]}, | |||
| "array_c": {"type": "bytes"}, | |||
| "array_d": {"type": "int64", | |||
| "shape": [2, -1]} | |||
| } | |||
| writer.set_header_size(1 << 14) | |||
| writer.set_page_size(1 << 15) | |||
| writer.add_schema(nlp_schema_json, "nlp_schema") | |||
| writer.write_raw_data(data) | |||
| writer.commit() | |||
| yield "yield_nlp_data" | |||
| for x in paths: | |||
| os.remove("{}".format(x)) | |||
| os.remove("{}.db".format(x)) | |||
| def test_nlp_compress_data(add_and_remove_nlp_compress_file): | |||
| """tutorial for nlp minderdataset.""" | |||
| data = [] | |||
| for row_id in range(16): | |||
| data.append({ | |||
| "label": row_id, | |||
| "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, | |||
| 255, 256, -32768, 32767, -32769, 32768, -2147483648, | |||
| 2147483647], dtype=np.int32), [-1]), | |||
| "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255, | |||
| 256, -32768, 32767, -32769, 32768, -2147483648, 2147483647, -2147483649, 2147483649, -922337036854775808, 9223372036854775807]), [1, -1]), | |||
| "array_c": str.encode("nlp data"), | |||
| "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1]) | |||
| }) | |||
| num_readers = 1 | |||
| data_set = ds.MindDataset( | |||
| NLP_FILE_NAME + "0", None, num_readers, shuffle=False) | |||
| assert data_set.get_dataset_size() == 16 | |||
| num_iter = 0 | |||
| for x, item in zip(data, data_set.create_dict_iterator()): | |||
| assert (item["array_a"] == x["array_a"]).all() | |||
| assert (item["array_b"] == x["array_b"]).all() | |||
| assert item["array_c"].tobytes() == x["array_c"] | |||
| assert (item["array_d"] == x["array_d"]).all() | |||
| assert item["label"] == x["label"] | |||
| num_iter += 1 | |||
| assert num_iter == 16 | |||
| def test_nlp_compress_data_old_version(add_and_remove_nlp_compress_file): | |||
| """tutorial for nlp minderdataset.""" | |||
| num_readers = 1 | |||
| data_set = ds.MindDataset( | |||
| NLP_FILE_NAME + "0", None, num_readers, shuffle=False) | |||
| old_data_set = ds.MindDataset( | |||
| OLD_NLP_FILE_NAME + "0", None, num_readers, shuffle=False) | |||
| assert old_data_set.get_dataset_size() == 16 | |||
| num_iter = 0 | |||
| for x, item in zip(old_data_set.create_dict_iterator(), data_set.create_dict_iterator()): | |||
| assert (item["array_a"] == x["array_a"]).all() | |||
| assert (item["array_b"] == x["array_b"]).all() | |||
| assert (item["array_c"] == x["array_c"]).all() | |||
| assert (item["array_d"] == x["array_d"]).all() | |||
| assert item["label"] == x["label"] | |||
| num_iter += 1 | |||
| assert num_iter == 16 | |||
| def test_cv_minddataset_writer_tutorial(): | |||
| """tutorial for cv dataset writer.""" | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| for x in paths: | |||
| os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None | |||
| os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None | |||
| os.remove("{}.db".format(x)) if os.path.exists( | |||
| "{}.db".format(x)) else None | |||
| writer = FileWriter(CV_FILE_NAME, FILES_NUM) | |||
| data = get_data(CV_DIR_NAME) | |||
| cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, | |||
| @@ -127,8 +221,10 @@ def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file): | |||
| num_shards=num_shards, shard_id=partition_id) | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- partition : {} ------------------------".format(partition_id)) | |||
| logger.info("-------------- item[label]: {} -----------------------".format(item["label"])) | |||
| logger.info( | |||
| "-------------- partition : {} ------------------------".format(partition_id)) | |||
| logger.info( | |||
| "-------------- item[label]: {} -----------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| return num_iter | |||
| @@ -147,9 +243,12 @@ def test_cv_minddataset_dataset_size(add_and_remove_cv_file): | |||
| data_set = data_set.repeat(repeat_num) | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- get dataset size {} -----------------".format(num_iter)) | |||
| logger.info("-------------- item[label]: {} ---------------------".format(item["label"])) | |||
| logger.info("-------------- item[data]: {} ----------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- get dataset size {} -----------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- item[label]: {} ---------------------".format(item["label"])) | |||
| logger.info( | |||
| "-------------- item[data]: {} ----------------------".format(item["data"])) | |||
| num_iter += 1 | |||
| assert num_iter == 20 | |||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | |||
| @@ -163,17 +262,22 @@ def test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file): | |||
| num_readers = 4 | |||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers) | |||
| decode_op = vision.Decode() | |||
| data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2) | |||
| data_set = data_set.map( | |||
| input_columns=["data"], operations=decode_op, num_parallel_workers=2) | |||
| resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR) | |||
| data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2) | |||
| data_set = data_set.map(input_columns="data", | |||
| operations=resize_op, num_parallel_workers=2) | |||
| data_set = data_set.batch(2) | |||
| data_set = data_set.repeat(2) | |||
| num_iter = 0 | |||
| labels = [] | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- get dataset size {} -----------------".format(num_iter)) | |||
| logger.info("-------------- item[label]: {} ---------------------".format(item["label"])) | |||
| logger.info("-------------- item[data]: {} ----------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- get dataset size {} -----------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- item[label]: {} ---------------------".format(item["label"])) | |||
| logger.info( | |||
| "-------------- item[data]: {} ----------------------".format(item["data"])) | |||
| num_iter += 1 | |||
| labels.append(item["label"]) | |||
| assert num_iter == 10 | |||
| @@ -189,15 +293,20 @@ def test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file): | |||
| num_readers = 4 | |||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers) | |||
| decode_op = vision.Decode() | |||
| data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2) | |||
| data_set = data_set.map( | |||
| input_columns=["data"], operations=decode_op, num_parallel_workers=2) | |||
| resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR) | |||
| data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2) | |||
| data_set = data_set.map(input_columns="data", | |||
| operations=resize_op, num_parallel_workers=2) | |||
| data_set = data_set.batch(32, drop_remainder=True) | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- get dataset size {} -----------------".format(num_iter)) | |||
| logger.info("-------------- item[label]: {} ---------------------".format(item["label"])) | |||
| logger.info("-------------- item[data]: {} ----------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- get dataset size {} -----------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- item[label]: {} ---------------------".format(item["label"])) | |||
| logger.info( | |||
| "-------------- item[data]: {} ----------------------".format(item["data"])) | |||
| num_iter += 1 | |||
| assert num_iter == 0 | |||
| @@ -206,7 +315,8 @@ def test_cv_minddataset_issue_888(add_and_remove_cv_file): | |||
| """issue 888 test.""" | |||
| columns_list = ["data", "label"] | |||
| num_readers = 2 | |||
| data = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, shuffle=False, num_shards=5, shard_id=1) | |||
| data = ds.MindDataset(CV_FILE_NAME + "0", columns_list, | |||
| num_readers, shuffle=False, num_shards=5, shard_id=1) | |||
| data = data.shuffle(2) | |||
| data = data.repeat(9) | |||
| num_iter = 0 | |||
| @@ -226,9 +336,12 @@ def test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file): | |||
| data_set = data_set.repeat(repeat_num) | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- block reader repeat tow {} -----------------".format(num_iter)) | |||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- block reader repeat tow {} -----------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| num_iter += 1 | |||
| assert num_iter == 20 | |||
| @@ -244,10 +357,14 @@ def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_rem | |||
| data_set = data_set.repeat(repeat_num) | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- block reader repeat tow {} -----------------".format(num_iter)) | |||
| logger.info("-------------- item[id]: {} ----------------------------".format(item["id"])) | |||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- block reader repeat tow {} -----------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- item[id]: {} ----------------------------".format(item["id"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| num_iter += 1 | |||
| assert num_iter == 20 | |||
| @@ -256,15 +373,21 @@ def test_cv_minddataset_reader_file_list(add_and_remove_cv_file): | |||
| """tutorial for cv minderdataset.""" | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)], columns_list, num_readers) | |||
| data_set = ds.MindDataset([CV_FILE_NAME + str(x) | |||
| for x in range(FILES_NUM)], columns_list, num_readers) | |||
| assert data_set.get_dataset_size() == 10 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| assert num_iter == 10 | |||
| @@ -277,11 +400,16 @@ def test_cv_minddataset_reader_one_partition(add_and_remove_cv_file): | |||
| assert data_set.get_dataset_size() < 10 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| assert num_iter < 10 | |||
| @@ -324,11 +452,16 @@ def test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file): | |||
| assert data_set.get_dataset_size() == 30 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| assert num_iter == 30 | |||
| if os.path.exists(CV1_FILE_NAME): | |||
| @@ -346,7 +479,8 @@ def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file): | |||
| for x in range(FILES_NUM)] | |||
| for x in paths: | |||
| os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None | |||
| os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None | |||
| os.remove("{}.db".format(x)) if os.path.exists( | |||
| "{}.db".format(x)) else None | |||
| writer = FileWriter(CV1_FILE_NAME, FILES_NUM) | |||
| data = get_data(CV_DIR_NAME) | |||
| cv_schema_json = {"id": {"type": "int32"}, | |||
| @@ -365,11 +499,16 @@ def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file): | |||
| assert data_set.get_dataset_size() < 20 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| assert num_iter < 20 | |||
| for x in paths: | |||
| @@ -385,11 +524,16 @@ def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file): | |||
| assert data_set.get_dataset_size() == 10 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| assert num_iter == 10 | |||
| @@ -401,10 +545,14 @@ def test_nlp_minddataset_reader_basic_tutorial(add_and_remove_nlp_file): | |||
| assert data_set.get_dataset_size() == 10 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- num_iter: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- item[id]: {} ------------------------".format(item["id"])) | |||
| logger.info("-------------- item[rating]: {} --------------------".format(item["rating"])) | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- num_iter: {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- item[id]: {} ------------------------".format(item["id"])) | |||
| logger.info( | |||
| "-------------- item[rating]: {} --------------------".format(item["rating"])) | |||
| logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format( | |||
| item["input_ids"], item["input_ids"].shape)) | |||
| logger.info("-------------- item[input_mask]: {}, shape: {} -----------------".format( | |||
| @@ -445,10 +593,13 @@ def test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_ | |||
| # define map operations | |||
| decode_op = vision.Decode() | |||
| resize_op = vision.Resize((resize_height, resize_width), ds.transforms.vision.Inter.LINEAR) | |||
| resize_op = vision.Resize( | |||
| (resize_height, resize_width), ds.transforms.vision.Inter.LINEAR) | |||
| data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=4) | |||
| data_set = data_set.map(input_columns=["data"], operations=resize_op, num_parallel_workers=4) | |||
| data_set = data_set.map( | |||
| input_columns=["data"], operations=decode_op, num_parallel_workers=4) | |||
| data_set = data_set.map( | |||
| input_columns=["data"], operations=resize_op, num_parallel_workers=4) | |||
| data_set = data_set.batch(2) | |||
| assert data_set.get_dataset_size() == 5 | |||
| @@ -468,11 +619,16 @@ def test_cv_minddataset_reader_no_columns(add_and_remove_cv_file): | |||
| assert data_set.get_dataset_size() == 10 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| assert num_iter == 10 | |||
| @@ -486,11 +642,16 @@ def test_cv_minddataset_reader_repeat_tutorial(add_and_remove_cv_file): | |||
| data_set = data_set.repeat(repeat_num) | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- repeat two test {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- len(item[data]): {} -----------------------".format(len(item["data"]))) | |||
| logger.info("-------------- item[data]: {} ----------------------------".format(item["data"])) | |||
| logger.info("-------------- item[file_name]: {} -----------------------".format(item["file_name"])) | |||
| logger.info("-------------- item[label]: {} ---------------------------".format(item["label"])) | |||
| logger.info( | |||
| "-------------- repeat two test {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- len(item[data]): {} -----------------------".format(len(item["data"]))) | |||
| logger.info( | |||
| "-------------- item[data]: {} ----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} -----------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ---------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| assert num_iter == 20 | |||
| @@ -599,7 +760,8 @@ def get_mkv_data(dir_name): | |||
| "id": index} | |||
| data_list.append(data_json) | |||
| index += 1 | |||
| logger.info('{} images are missing'.format(len(file_list) - len(data_list))) | |||
| logger.info('{} images are missing'.format( | |||
| len(file_list) - len(data_list))) | |||
| return data_list | |||
| @@ -686,6 +848,10 @@ def inputs(vectors, maxlen=50): | |||
| def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||
| mindrecord_file_name = "test.mindrecord" | |||
| if os.path.exists("{}".format(mindrecord_file_name)): | |||
| os.remove("{}".format(mindrecord_file_name)) | |||
| if os.path.exists("{}.db".format(mindrecord_file_name)): | |||
| os.remove("{}.db".format(x)) | |||
| data = [{"file_name": "001.jpg", "label": 4, | |||
| "image1": bytes("image1 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image1 bytes def", encoding='UTF-8'), | |||
| @@ -782,7 +948,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||
| data_value_to_list = [] | |||
| for item in data: | |||
| new_data = {} | |||
| new_data['file_name'] = np.asarray(list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8) | |||
| new_data['file_name'] = np.asarray( | |||
| list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8) | |||
| new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) | |||
| new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) | |||
| new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) | |||
| @@ -807,7 +974,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||
| assert len(item) == 13 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -815,7 +983,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||
| num_readers = 2 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["source_sos_ids", "source_sos_mask", "target_sos_ids"], | |||
| columns_list=["source_sos_ids", | |||
| "source_sos_mask", "target_sos_ids"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| @@ -832,7 +1001,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||
| num_readers = 1 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["image2", "source_sos_mask", "image3", "target_sos_ids"], | |||
| columns_list=[ | |||
| "image2", "source_sos_mask", "image3", "target_sos_ids"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| @@ -841,7 +1011,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||
| assert len(item) == 4 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -849,7 +1020,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||
| num_readers = 3 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["target_sos_ids", "image4", "source_sos_ids"], | |||
| columns_list=["target_sos_ids", | |||
| "image4", "source_sos_ids"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| @@ -858,7 +1030,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||
| assert len(item) == 3 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -866,7 +1039,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||
| num_readers = 3 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["target_sos_ids", "image5", "image4", "image3", "source_sos_ids"], | |||
| columns_list=["target_sos_ids", "image5", | |||
| "image4", "image3", "source_sos_ids"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| @@ -875,7 +1049,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||
| assert len(item) == 5 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -883,7 +1058,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||
| num_readers = 1 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["target_eos_mask", "image5", "image2", "source_sos_mask", "label"], | |||
| columns_list=["target_eos_mask", "image5", | |||
| "image2", "source_sos_mask", "label"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| @@ -892,7 +1068,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||
| assert len(item) == 5 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -910,7 +1087,8 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||
| assert len(item) == 11 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -975,7 +1153,8 @@ def test_write_with_multi_bytes_and_MindDataset(): | |||
| data_value_to_list = [] | |||
| for item in data: | |||
| new_data = {} | |||
| new_data['file_name'] = np.asarray(list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8) | |||
| new_data['file_name'] = np.asarray( | |||
| list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8) | |||
| new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) | |||
| new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) | |||
| new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) | |||
| @@ -994,7 +1173,8 @@ def test_write_with_multi_bytes_and_MindDataset(): | |||
| assert len(item) == 7 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -1011,7 +1191,8 @@ def test_write_with_multi_bytes_and_MindDataset(): | |||
| assert len(item) == 3 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -1028,7 +1209,8 @@ def test_write_with_multi_bytes_and_MindDataset(): | |||
| assert len(item) == 2 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -1045,7 +1227,8 @@ def test_write_with_multi_bytes_and_MindDataset(): | |||
| assert len(item) == 2 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -1062,7 +1245,8 @@ def test_write_with_multi_bytes_and_MindDataset(): | |||
| assert len(item) == 3 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -1070,7 +1254,8 @@ def test_write_with_multi_bytes_and_MindDataset(): | |||
| num_readers = 2 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["image4", "image5", "image2", "image3", "file_name"], | |||
| columns_list=["image4", "image5", | |||
| "image2", "image3", "file_name"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| @@ -1079,7 +1264,8 @@ def test_write_with_multi_bytes_and_MindDataset(): | |||
| assert len(item) == 5 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -1177,7 +1363,8 @@ def test_write_with_multi_array_and_MindDataset(): | |||
| assert len(item) == 8 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -1196,7 +1383,8 @@ def test_write_with_multi_array_and_MindDataset(): | |||
| assert len(item) == 6 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -1215,7 +1403,8 @@ def test_write_with_multi_array_and_MindDataset(): | |||
| assert len(item) == 3 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -1234,7 +1423,8 @@ def test_write_with_multi_array_and_MindDataset(): | |||
| assert len(item) == 3 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -1251,7 +1441,8 @@ def test_write_with_multi_array_and_MindDataset(): | |||
| assert len(item) == 1 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -1271,7 +1462,8 @@ def test_write_with_multi_array_and_MindDataset(): | |||
| assert len(item) == 8 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| assert (item[field] == | |||
| data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| @@ -25,8 +25,24 @@ from mindspore.mindrecord import SUCCESS | |||
| CIFAR100_DIR = "../data/mindrecord/testCifar100Data" | |||
| MINDRECORD_FILE = "./cifar100.mindrecord" | |||
| def test_cifar100_to_mindrecord_without_index_fields(): | |||
| @pytest.fixture | |||
| def fixture_file(): | |||
| """add/remove file""" | |||
| def remove_file(x): | |||
| if os.path.exists("{}".format(x)): | |||
| os.remove("{}".format(x)) | |||
| if os.path.exists("{}.db".format(x)): | |||
| os.remove("{}.db".format(x)) | |||
| if os.path.exists("{}_test".format(x)): | |||
| os.remove("{}_test".format(x)) | |||
| if os.path.exists("{}_test.db".format(x)): | |||
| os.remove("{}_test.db".format(x)) | |||
| remove_file(MINDRECORD_FILE) | |||
| yield "yield_fixture_data" | |||
| remove_file(MINDRECORD_FILE) | |||
| def test_cifar100_to_mindrecord_without_index_fields(fixture_file): | |||
| """test transform cifar100 dataset to mindrecord without index fields.""" | |||
| cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE) | |||
| ret = cifar100_transformer.transform() | |||
| @@ -34,25 +50,14 @@ def test_cifar100_to_mindrecord_without_index_fields(): | |||
| assert os.path.exists(MINDRECORD_FILE) | |||
| assert os.path.exists(MINDRECORD_FILE + "_test") | |||
| read() | |||
| os.remove("{}".format(MINDRECORD_FILE)) | |||
| os.remove("{}.db".format(MINDRECORD_FILE)) | |||
| os.remove("{}".format(MINDRECORD_FILE + "_test")) | |||
| os.remove("{}.db".format(MINDRECORD_FILE + "_test")) | |||
| def test_cifar100_to_mindrecord(): | |||
| def test_cifar100_to_mindrecord(fixture_file): | |||
| """test transform cifar100 dataset to mindrecord.""" | |||
| cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE) | |||
| cifar100_transformer.transform(['fine_label', 'coarse_label']) | |||
| assert os.path.exists(MINDRECORD_FILE) | |||
| assert os.path.exists(MINDRECORD_FILE + "_test") | |||
| read() | |||
| os.remove("{}".format(MINDRECORD_FILE)) | |||
| os.remove("{}.db".format(MINDRECORD_FILE)) | |||
| os.remove("{}".format(MINDRECORD_FILE + "_test")) | |||
| os.remove("{}.db".format(MINDRECORD_FILE + "_test")) | |||
| def read(): | |||
| @@ -77,8 +82,7 @@ def read(): | |||
| assert count == 4 | |||
| reader.close() | |||
| def test_cifar100_to_mindrecord_illegal_file_name(): | |||
| def test_cifar100_to_mindrecord_illegal_file_name(fixture_file): | |||
| """ | |||
| test transform cifar100 dataset to mindrecord | |||
| when file name contains illegal character. | |||
| @@ -88,8 +92,7 @@ def test_cifar100_to_mindrecord_illegal_file_name(): | |||
| cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename) | |||
| cifar100_transformer.transform() | |||
| def test_cifar100_to_mindrecord_filename_start_with_space(): | |||
| def test_cifar100_to_mindrecord_filename_start_with_space(fixture_file): | |||
| """ | |||
| test transform cifar10 dataset to mindrecord | |||
| when file name starts with space. | |||
| @@ -100,8 +103,7 @@ def test_cifar100_to_mindrecord_filename_start_with_space(): | |||
| cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename) | |||
| cifar100_transformer.transform() | |||
| def test_cifar100_to_mindrecord_filename_contain_space(): | |||
| def test_cifar100_to_mindrecord_filename_contain_space(fixture_file): | |||
| """ | |||
| test transform cifar10 dataset to mindrecord | |||
| when file name contains space. | |||
| @@ -111,14 +113,8 @@ def test_cifar100_to_mindrecord_filename_contain_space(): | |||
| cifar100_transformer.transform() | |||
| assert os.path.exists(filename) | |||
| assert os.path.exists(filename + "_test") | |||
| os.remove("{}".format(filename)) | |||
| os.remove("{}.db".format(filename)) | |||
| os.remove("{}".format(filename + "_test")) | |||
| os.remove("{}.db".format(filename + "_test")) | |||
| def test_cifar100_to_mindrecord_directory(): | |||
| def test_cifar100_to_mindrecord_directory(fixture_file): | |||
| """ | |||
| test transform cifar10 dataset to mindrecord | |||
| when destination path is directory. | |||
| @@ -129,8 +125,7 @@ def test_cifar100_to_mindrecord_directory(): | |||
| CIFAR100_DIR) | |||
| cifar100_transformer.transform() | |||
| def test_cifar100_to_mindrecord_filename_equals_cifar100(): | |||
| def test_cifar100_to_mindrecord_filename_equals_cifar100(fixture_file): | |||
| """ | |||
| test transform cifar10 dataset to mindrecord | |||
| when destination path equals source path. | |||
| @@ -24,36 +24,60 @@ from mindspore.mindrecord import MRMOpenError, SUCCESS | |||
| CIFAR10_DIR = "../data/mindrecord/testCifar10Data" | |||
| MINDRECORD_FILE = "./cifar10.mindrecord" | |||
| def test_cifar10_to_mindrecord_without_index_fields(): | |||
| @pytest.fixture | |||
| def fixture_file(): | |||
| """add/remove file""" | |||
| def remove_file(x): | |||
| if os.path.exists("{}".format(x)): | |||
| os.remove("{}".format(x)) | |||
| if os.path.exists("{}.db".format(x)): | |||
| os.remove("{}.db".format(x)) | |||
| if os.path.exists("{}_test".format(x)): | |||
| os.remove("{}_test".format(x)) | |||
| if os.path.exists("{}_test.db".format(x)): | |||
| os.remove("{}_test.db".format(x)) | |||
| remove_file(MINDRECORD_FILE) | |||
| yield "yield_fixture_data" | |||
| remove_file(MINDRECORD_FILE) | |||
| @pytest.fixture | |||
| def fixture_space_file(): | |||
| """add/remove file""" | |||
| def remove_file(x): | |||
| if os.path.exists("{}".format(x)): | |||
| os.remove("{}".format(x)) | |||
| if os.path.exists("{}.db".format(x)): | |||
| os.remove("{}.db".format(x)) | |||
| if os.path.exists("{}_test".format(x)): | |||
| os.remove("{}_test".format(x)) | |||
| if os.path.exists("{}_test.db".format(x)): | |||
| os.remove("{}_test.db".format(x)) | |||
| x = "./yes ok" | |||
| remove_file(x) | |||
| yield "yield_fixture_data" | |||
| remove_file(x) | |||
| def test_cifar10_to_mindrecord_without_index_fields(fixture_file): | |||
| """test transform cifar10 dataset to mindrecord without index fields.""" | |||
| cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE) | |||
| cifar10_transformer.transform() | |||
| assert os.path.exists(MINDRECORD_FILE) | |||
| assert os.path.exists(MINDRECORD_FILE + "_test") | |||
| read() | |||
| os.remove("{}".format(MINDRECORD_FILE)) | |||
| os.remove("{}.db".format(MINDRECORD_FILE)) | |||
| os.remove("{}".format(MINDRECORD_FILE + "_test")) | |||
| os.remove("{}.db".format(MINDRECORD_FILE + "_test")) | |||
| def test_cifar10_to_mindrecord(): | |||
| def test_cifar10_to_mindrecord(fixture_file): | |||
| """test transform cifar10 dataset to mindrecord.""" | |||
| cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE) | |||
| cifar10_transformer.transform(['label']) | |||
| assert os.path.exists(MINDRECORD_FILE) | |||
| assert os.path.exists(MINDRECORD_FILE + "_test") | |||
| read() | |||
| os.remove("{}".format(MINDRECORD_FILE)) | |||
| os.remove("{}.db".format(MINDRECORD_FILE)) | |||
| os.remove("{}".format(MINDRECORD_FILE + "_test")) | |||
| os.remove("{}.db".format(MINDRECORD_FILE + "_test")) | |||
| def test_cifar10_to_mindrecord_with_return(): | |||
| def test_cifar10_to_mindrecord_with_return(fixture_file): | |||
| """test transform cifar10 dataset to mindrecord.""" | |||
| cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE) | |||
| ret = cifar10_transformer.transform(['label']) | |||
| @@ -61,11 +85,6 @@ def test_cifar10_to_mindrecord_with_return(): | |||
| assert os.path.exists(MINDRECORD_FILE) | |||
| assert os.path.exists(MINDRECORD_FILE + "_test") | |||
| read() | |||
| os.remove("{}".format(MINDRECORD_FILE)) | |||
| os.remove("{}.db".format(MINDRECORD_FILE)) | |||
| os.remove("{}".format(MINDRECORD_FILE + "_test")) | |||
| os.remove("{}.db".format(MINDRECORD_FILE + "_test")) | |||
| def read(): | |||
| @@ -90,8 +109,7 @@ def read(): | |||
| assert count == 4 | |||
| reader.close() | |||
| def test_cifar10_to_mindrecord_illegal_file_name(): | |||
| def test_cifar10_to_mindrecord_illegal_file_name(fixture_file): | |||
| """ | |||
| test transform cifar10 dataset to mindrecord | |||
| when file name contains illegal character. | |||
| @@ -101,8 +119,7 @@ def test_cifar10_to_mindrecord_illegal_file_name(): | |||
| cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename) | |||
| cifar10_transformer.transform() | |||
| def test_cifar10_to_mindrecord_filename_start_with_space(): | |||
| def test_cifar10_to_mindrecord_filename_start_with_space(fixture_file): | |||
| """ | |||
| test transform cifar10 dataset to mindrecord | |||
| when file name starts with space. | |||
| @@ -113,8 +130,7 @@ def test_cifar10_to_mindrecord_filename_start_with_space(): | |||
| cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename) | |||
| cifar10_transformer.transform() | |||
| def test_cifar10_to_mindrecord_filename_contain_space(): | |||
| def test_cifar10_to_mindrecord_filename_contain_space(fixture_space_file): | |||
| """ | |||
| test transform cifar10 dataset to mindrecord | |||
| when file name contains space. | |||
| @@ -124,14 +140,8 @@ def test_cifar10_to_mindrecord_filename_contain_space(): | |||
| cifar10_transformer.transform() | |||
| assert os.path.exists(filename) | |||
| assert os.path.exists(filename + "_test") | |||
| os.remove("{}".format(filename)) | |||
| os.remove("{}.db".format(filename)) | |||
| os.remove("{}".format(filename + "_test")) | |||
| os.remove("{}.db".format(filename + "_test")) | |||
| def test_cifar10_to_mindrecord_directory(): | |||
| def test_cifar10_to_mindrecord_directory(fixture_file): | |||
| """ | |||
| test transform cifar10 dataset to mindrecord | |||
| when destination path is directory. | |||
| @@ -25,6 +25,26 @@ IMAGENET_IMAGE_DIR = "../data/mindrecord/testImageNetDataWhole/images" | |||
| MINDRECORD_FILE = "../data/mindrecord/testImageNetDataWhole/imagenet.mindrecord" | |||
| PARTITION_NUMBER = 4 | |||
| @pytest.fixture | |||
| def fixture_file(): | |||
| """add/remove file""" | |||
| def remove_one_file(x): | |||
| if os.path.exists(x): | |||
| os.remove(x) | |||
| def remove_file(): | |||
| x = MINDRECORD_FILE | |||
| remove_one_file(x) | |||
| x = MINDRECORD_FILE + ".db" | |||
| remove_one_file(x) | |||
| for i in range(PARTITION_NUMBER): | |||
| x = MINDRECORD_FILE + str(i) | |||
| remove_one_file(x) | |||
| x = MINDRECORD_FILE + str(i) + ".db" | |||
| remove_one_file(x) | |||
| remove_file() | |||
| yield "yield_fixture_data" | |||
| remove_file() | |||
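The fixture above follows pytest's yield-fixture protocol: everything before `yield` runs as setup and everything after runs as teardown, even when the test body fails, which is what allows the per-test `os.remove` calls below to be dropped. A minimal sketch of the same pattern is shown here; the fixture name and the glob pattern are illustrative assumptions, not part of this change, and it reuses the module-level `MINDRECORD_FILE` constant.

```python
import glob
import os

import pytest

MINDRECORD_FILE = "../data/mindrecord/testImageNetDataWhole/imagenet.mindrecord"

# Hypothetical fixture, shown only to illustrate the setup/teardown structure
# of the fixture added above.
@pytest.fixture
def cleanup_imagenet_mindrecord():
    """Remove the generated imagenet mindrecord files before and after a test."""
    def _remove_all():
        # Covers the base file, its ".db" index and all numbered partitions.
        for path in glob.glob(MINDRECORD_FILE + "*"):
            os.remove(path)
    _remove_all()   # setup: start from a clean state
    yield           # the test body runs here
    _remove_all()   # teardown: runs even if the test fails
```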
| def read(filename): | |||
| """test file reade""" | |||
| @@ -38,8 +58,7 @@ def read(filename): | |||
| assert count == 20 | |||
| reader.close() | |||
| def test_imagenet_to_mindrecord(): | |||
| def test_imagenet_to_mindrecord(fixture_file): | |||
| """test transform imagenet dataset to mindrecord.""" | |||
| imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR, | |||
| MINDRECORD_FILE, PARTITION_NUMBER) | |||
| @@ -48,12 +67,8 @@ def test_imagenet_to_mindrecord(): | |||
| assert os.path.exists(MINDRECORD_FILE + str(i)) | |||
| assert os.path.exists(MINDRECORD_FILE + str(i) + ".db") | |||
| read(MINDRECORD_FILE + "0") | |||
| for i in range(PARTITION_NUMBER): | |||
| os.remove(MINDRECORD_FILE + str(i)) | |||
| os.remove(MINDRECORD_FILE + str(i) + ".db") | |||
| def test_imagenet_to_mindrecord_default_partition_number(): | |||
| def test_imagenet_to_mindrecord_default_partition_number(fixture_file): | |||
| """ | |||
| test transform imagenet dataset to mindrecord | |||
| when partition number is default. | |||
| @@ -64,11 +79,8 @@ def test_imagenet_to_mindrecord_default_partition_number(): | |||
| assert os.path.exists(MINDRECORD_FILE) | |||
| assert os.path.exists(MINDRECORD_FILE + ".db") | |||
| read(MINDRECORD_FILE) | |||
| os.remove("{}".format(MINDRECORD_FILE)) | |||
| os.remove("{}.db".format(MINDRECORD_FILE)) | |||
| def test_imagenet_to_mindrecord_partition_number_0(): | |||
| def test_imagenet_to_mindrecord_partition_number_0(fixture_file): | |||
| """ | |||
| test transform imagenet dataset to mindrecord | |||
| when partition number is 0. | |||
| @@ -79,8 +91,7 @@ def test_imagenet_to_mindrecord_partition_number_0(): | |||
| MINDRECORD_FILE, 0) | |||
| imagenet_transformer.transform() | |||
| def test_imagenet_to_mindrecord_partition_number_none(): | |||
| def test_imagenet_to_mindrecord_partition_number_none(fixture_file): | |||
| """ | |||
| test transform imagenet dataset to mindrecord | |||
| when partition number is none. | |||
| @@ -92,8 +103,7 @@ def test_imagenet_to_mindrecord_partition_number_none(): | |||
| MINDRECORD_FILE, None) | |||
| imagenet_transformer.transform() | |||
| def test_imagenet_to_mindrecord_illegal_filename(): | |||
| def test_imagenet_to_mindrecord_illegal_filename(fixture_file): | |||
| """ | |||
| test transform imagenet dataset to mindrecord | |||
| when file name contains illegal character. | |||
| @@ -26,6 +26,34 @@ CV_FILE_NAME = "./imagenet.mindrecord" | |||
| NLP_FILE_NAME = "./aclImdb.mindrecord" | |||
| FILES_NUM = 4 | |||
| def remove_one_file(x): | |||
| if os.path.exists(x): | |||
| os.remove(x) | |||
| def remove_file(file_name): | |||
| x = file_name | |||
| remove_one_file(x) | |||
| x = file_name + ".db" | |||
| remove_one_file(x) | |||
| for i in range(FILES_NUM): | |||
| x = file_name + str(i) | |||
| remove_one_file(x) | |||
| x = file_name + str(i) + ".db" | |||
| remove_one_file(x) | |||
| @pytest.fixture | |||
| def fixture_cv_file(): | |||
| """add/remove file""" | |||
| remove_file(CV_FILE_NAME) | |||
| yield "yield_fixture_data" | |||
| remove_file(CV_FILE_NAME) | |||
| @pytest.fixture | |||
| def fixture_nlp_file(): | |||
| """add/remove file""" | |||
| remove_file(NLP_FILE_NAME) | |||
| yield "yield_fixture_data" | |||
| remove_file(NLP_FILE_NAME) | |||
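With these two fixtures in place, every reader/writer test in this file simply declares the fixture as a parameter: pytest matches the argument name, runs the cleanup before the test body, and runs it again afterwards. A hedged sketch of how such a test reads after the change follows; the test name, the row count, and the exact iteration over `FileReader.get_next()` are illustrative, not taken from the diff.

```python
def test_example_reader(fixture_cv_file):
    """Illustrative only: shows the fixture replacing manual cleanup."""
    create_cv_mindrecord(1)          # helper defined in this test module
    reader = FileReader(CV_FILE_NAME)
    count = 0
    for _ in reader.get_next():      # iterate all rows in the mindrecord file
        count += 1
    assert count > 0                 # exact count depends on create_cv_mindrecord
    reader.close()
    # no trailing os.remove calls: fixture_cv_file deletes the files in teardown
```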
| def test_cv_file_writer_shard_num_none(): | |||
| """test cv file writer when shard num is None.""" | |||
| @@ -83,8 +111,7 @@ def test_lack_partition_and_db(): | |||
| 'error_msg: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| def test_lack_db(): | |||
| def test_lack_db(fixture_cv_file): | |||
| """test file reader when db file does not exist.""" | |||
| create_cv_mindrecord(1) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| @@ -94,10 +121,8 @@ def test_lack_db(): | |||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | |||
| 'error_msg: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| os.remove(CV_FILE_NAME) | |||
| def test_lack_some_partition_and_db(): | |||
| def test_lack_some_partition_and_db(fixture_cv_file): | |||
| """test file reader when some partition and db do not exist.""" | |||
| create_cv_mindrecord(4) | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| @@ -110,16 +135,8 @@ def test_lack_some_partition_and_db(): | |||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | |||
| 'error_msg: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| for x in paths: | |||
| if os.path.exists("{}".format(x)): | |||
| os.remove("{}".format(x)) | |||
| if os.path.exists("{}.db".format(x)): | |||
| os.remove("{}.db".format(x)) | |||
| def test_lack_some_partition_first(): | |||
| def test_lack_some_partition_first(fixture_cv_file): | |||
| """test file reader when first partition does not exist.""" | |||
| create_cv_mindrecord(4) | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| @@ -131,14 +148,8 @@ def test_lack_some_partition_first(): | |||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | |||
| 'error_msg: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| for x in paths: | |||
| if os.path.exists("{}".format(x)): | |||
| os.remove("{}".format(x)) | |||
| if os.path.exists("{}.db".format(x)): | |||
| os.remove("{}.db".format(x)) | |||
| def test_lack_some_partition_middle(): | |||
| def test_lack_some_partition_middle(fixture_cv_file): | |||
| """test file reader when some partition does not exist.""" | |||
| create_cv_mindrecord(4) | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| @@ -150,14 +161,8 @@ def test_lack_some_partition_middle(): | |||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | |||
| 'error_msg: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| for x in paths: | |||
| if os.path.exists("{}".format(x)): | |||
| os.remove("{}".format(x)) | |||
| if os.path.exists("{}.db".format(x)): | |||
| os.remove("{}.db".format(x)) | |||
| def test_lack_some_partition_last(): | |||
| def test_lack_some_partition_last(fixture_cv_file): | |||
| """test file reader when last partition does not exist.""" | |||
| create_cv_mindrecord(4) | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| @@ -169,14 +174,8 @@ def test_lack_some_partition_last(): | |||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | |||
| 'error_msg: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| for x in paths: | |||
| if os.path.exists("{}".format(x)): | |||
| os.remove("{}".format(x)) | |||
| if os.path.exists("{}.db".format(x)): | |||
| os.remove("{}.db".format(x)) | |||
| def test_mindpage_lack_some_partition(): | |||
| def test_mindpage_lack_some_partition(fixture_cv_file): | |||
| """test page reader when some partition does not exist.""" | |||
| create_cv_mindrecord(4) | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| @@ -187,14 +186,8 @@ def test_mindpage_lack_some_partition(): | |||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | |||
| 'error_msg: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| for x in paths: | |||
| if os.path.exists("{}".format(x)): | |||
| os.remove("{}".format(x)) | |||
| if os.path.exists("{}.db".format(x)): | |||
| os.remove("{}.db".format(x)) | |||
| def test_lack_some_db(): | |||
| def test_lack_some_db(fixture_cv_file): | |||
| """test file reader when some db does not exist.""" | |||
| create_cv_mindrecord(4) | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| @@ -206,11 +199,6 @@ def test_lack_some_db(): | |||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | |||
| 'error_msg: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| for x in paths: | |||
| if os.path.exists("{}".format(x)): | |||
| os.remove("{}".format(x)) | |||
| if os.path.exists("{}.db".format(x)): | |||
| os.remove("{}.db".format(x)) | |||
| def test_invalid_mindrecord(): | |||
| @@ -225,8 +213,7 @@ def test_invalid_mindrecord(): | |||
| in str(err.value) | |||
| os.remove(CV_FILE_NAME) | |||
| def test_invalid_db(): | |||
| def test_invalid_db(fixture_cv_file): | |||
| """test file reader when the content of db is illegal.""" | |||
| create_cv_mindrecord(1) | |||
| os.remove("imagenet.mindrecord.db") | |||
| @@ -237,11 +224,8 @@ def test_invalid_db(): | |||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | |||
| 'error_msg: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| os.remove("imagenet.mindrecord") | |||
| os.remove("imagenet.mindrecord.db") | |||
| def test_overwrite_invalid_mindrecord(): | |||
| def test_overwrite_invalid_mindrecord(fixture_cv_file): | |||
| """test file writer when overwrite invalid mindreocrd file.""" | |||
| with open(CV_FILE_NAME, 'w') as f: | |||
| f.write('just for test') | |||
| @@ -250,10 +234,8 @@ def test_overwrite_invalid_mindrecord(): | |||
| assert '[MRMOpenError]: error_code: 1347690596, ' \ | |||
| 'error_msg: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| os.remove(CV_FILE_NAME) | |||
| def test_overwrite_invalid_db(): | |||
| def test_overwrite_invalid_db(fixture_cv_file): | |||
| """test file writer when overwrite invalid db file.""" | |||
| with open('imagenet.mindrecord.db', 'w') as f: | |||
| f.write('just for test') | |||
| @@ -261,11 +243,8 @@ def test_overwrite_invalid_db(): | |||
| create_cv_mindrecord(1) | |||
| assert '[MRMGenerateIndexError]: error_code: 1347690612, ' \ | |||
| 'error_msg: Failed to generate index.' in str(err.value) | |||
| os.remove("imagenet.mindrecord") | |||
| os.remove("imagenet.mindrecord.db") | |||
| def test_read_after_close(): | |||
| def test_read_after_close(fixture_cv_file): | |||
| """test file reader when close read.""" | |||
| create_cv_mindrecord(1) | |||
| reader = FileReader(CV_FILE_NAME) | |||
| @@ -275,11 +254,8 @@ def test_read_after_close(): | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 0 | |||
| os.remove(CV_FILE_NAME) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| def test_file_read_after_read(): | |||
| def test_file_read_after_read(fixture_cv_file): | |||
| """test file reader when finish read.""" | |||
| create_cv_mindrecord(1) | |||
| reader = FileReader(CV_FILE_NAME) | |||
| @@ -295,8 +271,6 @@ def test_file_read_after_read(): | |||
| cnt = cnt + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert cnt == 0 | |||
| os.remove(CV_FILE_NAME) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| def test_cv_file_writer_shard_num_greater_than_1000(): | |||
| @@ -312,8 +286,7 @@ def test_add_index_without_add_schema(): | |||
| fw.add_index(["label"]) | |||
| assert 'Failed to get meta info' in str(err.value) | |||
| def test_mindpage_pageno_pagesize_not_int(): | |||
| def test_mindpage_pageno_pagesize_not_int(fixture_cv_file): | |||
| """test page reader when some partition does not exist.""" | |||
| create_cv_mindrecord(4) | |||
| reader = MindPage(CV_FILE_NAME + "0") | |||
| @@ -342,14 +315,8 @@ def test_mindpage_pageno_pagesize_not_int(): | |||
| with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."): | |||
| reader.read_at_page_by_id(99999, 0, 1) | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| for x in paths: | |||
| os.remove("{}".format(x)) | |||
| os.remove("{}.db".format(x)) | |||
| def test_mindpage_filename_not_exist(): | |||
| def test_mindpage_filename_not_exist(fixture_cv_file): | |||
| """test page reader when some partition does not exist.""" | |||
| create_cv_mindrecord(4) | |||
| reader = MindPage(CV_FILE_NAME + "0") | |||
| @@ -374,6 +341,3 @@ def test_mindpage_filename_not_exist(): | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| for x in paths: | |||
| os.remove("{}".format(x)) | |||
| os.remove("{}.db".format(x)) | |||
| @@ -14,6 +14,7 @@ | |||
| """test mnist to mindrecord tool""" | |||
| import cv2 | |||
| import gzip | |||
| import pytest | |||
| import numpy as np | |||
| import os | |||
| @@ -27,6 +28,34 @@ PARTITION_NUM = 4 | |||
| IMAGE_SIZE = 28 | |||
| NUM_CHANNELS = 1 | |||
| @pytest.fixture | |||
| def fixture_file(): | |||
| """add/remove file""" | |||
| def remove_one_file(x): | |||
| if os.path.exists(x): | |||
| os.remove(x) | |||
| def remove_file(): | |||
| x = "mnist_train.mindrecord" | |||
| remove_one_file(x) | |||
| x = "mnist_train.mindrecord.db" | |||
| remove_one_file(x) | |||
| x = "mnist_test.mindrecord" | |||
| remove_one_file(x) | |||
| x = "mnist_test.mindrecord.db" | |||
| remove_one_file(x) | |||
| for i in range(PARTITION_NUM): | |||
| x = "mnist_train.mindrecord" + str(i) | |||
| remove_one_file(x) | |||
| x = "mnist_train.mindrecord" + str(i) + ".db" | |||
| remove_one_file(x) | |||
| x = "mnist_test.mindrecord" + str(i) | |||
| remove_one_file(x) | |||
| x = "mnist_test.mindrecord" + str(i) + ".db" | |||
| remove_one_file(x) | |||
| remove_file() | |||
| yield "yield_fixture_data" | |||
| remove_file() | |||
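The mnist fixture spells out each of the eight base and partition file names explicitly. An equivalent but more compact cleanup could rely on glob, as sketched below; the fixture name and patterns are illustrative assumptions rather than what the change actually adds.

```python
import glob
import os

import pytest

@pytest.fixture
def fixture_mnist_files():
    """Remove generated mnist mindrecord files before and after the test."""
    def _remove_all():
        # Matches mnist_train.mindrecord, mnist_test.mindrecord, their ".db"
        # index files, and every partition suffix such as mnist_train.mindrecord0.
        for pattern in ("mnist_train.mindrecord*", "mnist_test.mindrecord*"):
            for path in glob.glob(pattern):
                os.remove(path)
    _remove_all()   # setup
    yield
    _remove_all()   # teardown, runs even if the test fails
```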
| def read(train_name, test_name): | |||
| """test file reader""" | |||
| @@ -51,7 +80,7 @@ def read(train_name, test_name): | |||
| reader.close() | |||
| def test_mnist_to_mindrecord(): | |||
| def test_mnist_to_mindrecord(fixture_file): | |||
| """test transform mnist dataset to mindrecord.""" | |||
| mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME) | |||
| mnist_transformer.transform() | |||
| @@ -60,13 +89,7 @@ def test_mnist_to_mindrecord(): | |||
| read("mnist_train.mindrecord", "mnist_test.mindrecord") | |||
| os.remove("{}".format("mnist_train.mindrecord")) | |||
| os.remove("{}.db".format("mnist_train.mindrecord")) | |||
| os.remove("{}".format("mnist_test.mindrecord")) | |||
| os.remove("{}.db".format("mnist_test.mindrecord")) | |||
| def test_mnist_to_mindrecord_compare_data(): | |||
| def test_mnist_to_mindrecord_compare_data(fixture_file): | |||
| """test transform mnist dataset to mindrecord and compare data.""" | |||
| mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME) | |||
| mnist_transformer.transform() | |||
| @@ -121,21 +144,10 @@ def test_mnist_to_mindrecord_compare_data(): | |||
| assert np.array(x['label']) == label | |||
| reader.close() | |||
| os.remove("{}".format("mnist_train.mindrecord")) | |||
| os.remove("{}.db".format("mnist_train.mindrecord")) | |||
| os.remove("{}".format("mnist_test.mindrecord")) | |||
| os.remove("{}.db".format("mnist_test.mindrecord")) | |||
| def test_mnist_to_mindrecord_multi_partition(): | |||
| def test_mnist_to_mindrecord_multi_partition(fixture_file): | |||
| """test transform mnist dataset to multiple mindrecord files.""" | |||
| mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME, PARTITION_NUM) | |||
| mnist_transformer.transform() | |||
| read("mnist_train.mindrecord0", "mnist_test.mindrecord0") | |||
| for i in range(PARTITION_NUM): | |||
| os.remove("{}".format("mnist_train.mindrecord" + str(i))) | |||
| os.remove("{}.db".format("mnist_train.mindrecord" + str(i))) | |||
| os.remove("{}".format("mnist_test.mindrecord" + str(i))) | |||
| os.remove("{}.db".format("mnist_test.mindrecord" + str(i))) | |||