| @@ -165,12 +165,22 @@ Status MindRecordOp::Init() { | |||
| Status MindRecordOp::SetColumnsBlob() { | |||
| columns_blob_ = shard_reader_->get_blob_fields().second; | |||
| // get the exactly blob fields by columns_to_load_ | |||
| std::vector<std::string> columns_blob_exact; | |||
| for (auto &blob_field : columns_blob_) { | |||
| for (auto &column : columns_to_load_) { | |||
| if (column.compare(blob_field) == 0) { | |||
| columns_blob_exact.push_back(blob_field); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| columns_blob_index_ = std::vector<int32_t>(columns_to_load_.size(), -1); | |||
| int32_t iBlob = 0; | |||
| for (uint32_t i = 0; i < columns_blob_.size(); ++i) { | |||
| if (column_name_mapping_.count(columns_blob_[i])) { | |||
| columns_blob_index_[column_name_mapping_[columns_blob_[i]]] = iBlob++; | |||
| } | |||
| for (auto &blob_exact : columns_blob_exact) { | |||
| columns_blob_index_[column_name_mapping_[blob_exact]] = iBlob++; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -294,6 +294,10 @@ class ShardReader { | |||
| /// \brief get number of classes | |||
| int64_t GetNumClasses(const std::string &file_path, const std::string &category_field); | |||
| /// \brief get exactly blob fields data by indices | |||
| std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes, | |||
| std::vector<uint32_t> &ordered_selected_columns_index); | |||
| protected: | |||
| uint64_t header_size_; // header size | |||
| uint64_t page_size_; // page size | |||
| @@ -790,6 +790,8 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer, | |||
| n_consumer = kMinConsumerCount; | |||
| } | |||
| CheckNlp(); | |||
| // dead code | |||
| if (nlp_) { | |||
| selected_columns_ = selected_columns; | |||
| } else { | |||
| @@ -801,6 +803,7 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer, | |||
| } | |||
| } | |||
| } | |||
| selected_columns_ = selected_columns; | |||
| if (CheckColumnList(selected_columns_) == FAILED) { | |||
| MS_LOG(ERROR) << "Illegal column list"; | |||
| @@ -1060,6 +1063,36 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u | |||
| return SUCCESS; | |||
| } | |||
| std::vector<uint8_t> ShardReader::ExtractBlobFieldBySelectColumns( | |||
| std::vector<uint8_t> &blob_fields_bytes, std::vector<uint32_t> &ordered_selected_columns_index) { | |||
| std::vector<uint8_t> exactly_blob_fields_bytes; | |||
| auto uint64_from_bytes = [&](int64_t pos) { | |||
| uint64_t result = 0; | |||
| for (uint64_t n = 0; n < kInt64Len; n++) { | |||
| result = (result << 8) + blob_fields_bytes[pos + n]; | |||
| } | |||
| return result; | |||
| }; | |||
| // get the exactly blob fields | |||
| uint32_t current_index = 0; | |||
| uint64_t current_offset = 0; | |||
| uint64_t data_len = uint64_from_bytes(current_offset); | |||
| while (current_offset < blob_fields_bytes.size()) { | |||
| if (std::any_of(ordered_selected_columns_index.begin(), ordered_selected_columns_index.end(), | |||
| [¤t_index](uint32_t &index) { return index == current_index; })) { | |||
| exactly_blob_fields_bytes.insert(exactly_blob_fields_bytes.end(), blob_fields_bytes.begin() + current_offset, | |||
| blob_fields_bytes.begin() + current_offset + kInt64Len + data_len); | |||
| } | |||
| current_index++; | |||
| current_offset += kInt64Len + data_len; | |||
| data_len = uint64_from_bytes(current_offset); | |||
| } | |||
| return exactly_blob_fields_bytes; | |||
| } | |||
| TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) { | |||
| // All tasks are done | |||
| if (task_id >= static_cast<int>(tasks_.Size())) { | |||
| @@ -1077,6 +1110,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ | |||
| return std::make_pair(FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>()); | |||
| } | |||
| const std::shared_ptr<Page> &page = ret.second; | |||
| // Pack image list | |||
| std::vector<uint8_t> images(addr[1] - addr[0]); | |||
| auto file_offset = header_size_ + page_size_ * (page->get_page_id()) + addr[0]; | |||
| @@ -1096,10 +1130,42 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ | |||
| return std::make_pair(FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>()); | |||
| } | |||
| // extract the exactly blob bytes by selected columns | |||
| std::vector<uint8_t> images_with_exact_columns; | |||
| if (selected_columns_.size() == 0) { | |||
| images_with_exact_columns = images; | |||
| } else { | |||
| auto blob_fields = get_blob_fields(); | |||
| std::vector<uint32_t> ordered_selected_columns_index; | |||
| uint32_t index = 0; | |||
| for (auto &blob_field : blob_fields.second) { | |||
| for (auto &field : selected_columns_) { | |||
| if (field.compare(blob_field) == 0) { | |||
| ordered_selected_columns_index.push_back(index); | |||
| break; | |||
| } | |||
| } | |||
| index++; | |||
| } | |||
| if (ordered_selected_columns_index.size() != 0) { | |||
| // extract the images | |||
| if (blob_fields.second.size() == 1) { | |||
| if (ordered_selected_columns_index.size() == 1) { | |||
| images_with_exact_columns = images; | |||
| } | |||
| } else { | |||
| images_with_exact_columns = ExtractBlobFieldBySelectColumns(images, ordered_selected_columns_index); | |||
| } | |||
| } | |||
| } | |||
| // Deliver batch data to output map | |||
| std::vector<std::tuple<std::vector<uint8_t>, json>> batch; | |||
| if (nlp_) { | |||
| json blob_fields = json::from_msgpack(images); | |||
| // dead code | |||
| json blob_fields = json::from_msgpack(images_with_exact_columns); | |||
| json merge; | |||
| if (selected_columns_.size() > 0) { | |||
| @@ -1117,7 +1183,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ | |||
| } | |||
| batch.emplace_back(std::vector<uint8_t>{}, std::move(merge)); | |||
| } else { | |||
| batch.emplace_back(std::move(images), std::move(std::get<2>(task))); | |||
| batch.emplace_back(std::move(images_with_exact_columns), std::move(std::get<2>(task))); | |||
| } | |||
| return std::make_pair(SUCCESS, std::move(batch)); | |||
| } | |||
| @@ -92,15 +92,25 @@ def populate_data(raw, blob, columns, blob_fields, schema): | |||
| if raw: | |||
| # remove dummy fileds | |||
| raw = {k: v for k, v in raw.items() if k in schema} | |||
| else: | |||
| raw = {} | |||
| if not blob_fields: | |||
| return raw | |||
| # Get the order preserving sequence of columns in blob | |||
| ordered_columns = [] | |||
| if columns: | |||
| for blob_field in blob_fields: | |||
| if blob_field in columns: | |||
| ordered_columns.append(blob_field) | |||
| else: | |||
| ordered_columns = blob_fields | |||
| blob_bytes = bytes(blob) | |||
| def _render_raw(field, blob_data): | |||
| data_type = schema[field]['type'] | |||
| data_shape = schema[field]['shape'] if 'shape' in schema[field] else [] | |||
| if columns and field not in columns: | |||
| return | |||
| if data_shape: | |||
| try: | |||
| raw[field] = np.reshape(np.frombuffer(blob_data, dtype=data_type), data_shape) | |||
| @@ -110,7 +120,9 @@ def populate_data(raw, blob, columns, blob_fields, schema): | |||
| raw[field] = blob_data | |||
| if len(blob_fields) == 1: | |||
| _render_raw(blob_fields[0], blob_bytes) | |||
| if len(ordered_columns) == 1: | |||
| _render_raw(blob_fields[0], blob_bytes) | |||
| return raw | |||
| return raw | |||
| def _int_from_bytes(xbytes: bytes) -> int: | |||
| @@ -125,6 +137,6 @@ def populate_data(raw, blob, columns, blob_fields, schema): | |||
| start += 8 | |||
| return blob_bytes[start : start + n_bytes] | |||
| for i, blob_field in enumerate(blob_fields): | |||
| for i, blob_field in enumerate(ordered_columns): | |||
| _render_raw(blob_field, _blob_at_position(i)) | |||
| return raw | |||
| @@ -545,3 +545,597 @@ def inputs(vectors, maxlen=50): | |||
| mask = [1]*length + [0]*(maxlen-length) | |||
| segment = [0]*maxlen | |||
| return input_, mask, segment | |||
| def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||
| mindrecord_file_name = "test.mindrecord" | |||
| data = [{"file_name": "001.jpg", "label": 4, | |||
| "image1": bytes("image1 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image1 bytes def", encoding='UTF-8'), | |||
| "source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "image3": bytes("image1 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image1 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image1 bytes mno", encoding='UTF-8'), | |||
| "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)}, | |||
| {"file_name": "002.jpg", "label": 5, | |||
| "image1": bytes("image2 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image2 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image2 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image2 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image2 bytes mno", encoding='UTF-8'), | |||
| "source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)}, | |||
| {"file_name": "003.jpg", "label": 6, | |||
| "source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "image1": bytes("image3 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image3 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image3 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image3 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image3 bytes mno", encoding='UTF-8'), | |||
| "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)}, | |||
| {"file_name": "004.jpg", "label": 7, | |||
| "source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "image1": bytes("image4 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image4 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image4 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image4 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image4 bytes mno", encoding='UTF-8'), | |||
| "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)}, | |||
| {"file_name": "005.jpg", "label": 8, | |||
| "source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "image1": bytes("image5 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image5 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image5 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image5 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image5 bytes mno", encoding='UTF-8'), | |||
| "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)}, | |||
| {"file_name": "006.jpg", "label": 9, | |||
| "source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64), | |||
| "image1": bytes("image6 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image6 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image6 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image6 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image6 bytes mno", encoding='UTF-8'), | |||
| "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)} | |||
| ] | |||
| writer = FileWriter(mindrecord_file_name) | |||
| schema = {"file_name": {"type": "string"}, | |||
| "image1": {"type": "bytes"}, | |||
| "image2": {"type": "bytes"}, | |||
| "source_sos_ids": {"type": "int64", "shape": [-1]}, | |||
| "source_sos_mask": {"type": "int64", "shape": [-1]}, | |||
| "image3": {"type": "bytes"}, | |||
| "image4": {"type": "bytes"}, | |||
| "image5": {"type": "bytes"}, | |||
| "target_sos_ids": {"type": "int64", "shape": [-1]}, | |||
| "target_sos_mask": {"type": "int64", "shape": [-1]}, | |||
| "target_eos_ids": {"type": "int64", "shape": [-1]}, | |||
| "target_eos_mask": {"type": "int64", "shape": [-1]}, | |||
| "label": {"type": "int32"}} | |||
| writer.add_schema(schema, "data is so cool") | |||
| writer.write_raw_data(data) | |||
| writer.commit() | |||
| # change data value to list | |||
| data_value_to_list = [] | |||
| for item in data: | |||
| new_data = {} | |||
| new_data['file_name'] = np.asarray(list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8) | |||
| new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) | |||
| new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) | |||
| new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) | |||
| new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8) | |||
| new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8) | |||
| new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8) | |||
| new_data['source_sos_ids'] = item["source_sos_ids"] | |||
| new_data['source_sos_mask'] = item["source_sos_mask"] | |||
| new_data['target_sos_ids'] = item["target_sos_ids"] | |||
| new_data['target_sos_mask'] = item["target_sos_mask"] | |||
| new_data['target_eos_ids'] = item["target_eos_ids"] | |||
| new_data['target_eos_mask'] = item["target_eos_mask"] | |||
| data_value_to_list.append(new_data) | |||
| num_readers = 2 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 13 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| num_readers = 2 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["source_sos_ids", "source_sos_mask", "target_sos_ids"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 3 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| num_readers = 1 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["image2", "source_sos_mask", "image3", "target_sos_ids"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 4 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| num_readers = 3 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["target_sos_ids", "image4", "source_sos_ids"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 3 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| num_readers = 3 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["target_sos_ids", "image5", "image4", "image3", "source_sos_ids"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 5 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| num_readers = 1 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["target_eos_mask", "image5", "image2", "source_sos_mask", "label"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 5 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| num_readers = 2 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["label", "target_eos_mask", "image1", "target_eos_ids", "source_sos_mask", | |||
| "image2", "image4", "image3", "source_sos_ids", "image5", "file_name"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 11 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| os.remove("{}".format(mindrecord_file_name)) | |||
| os.remove("{}.db".format(mindrecord_file_name)) | |||
| def test_write_with_multi_bytes_and_MindDataset(): | |||
| mindrecord_file_name = "test.mindrecord" | |||
| data = [{"file_name": "001.jpg", "label": 43, | |||
| "image1": bytes("image1 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image1 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image1 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image1 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image1 bytes mno", encoding='UTF-8')}, | |||
| {"file_name": "002.jpg", "label": 91, | |||
| "image1": bytes("image2 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image2 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image2 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image2 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image2 bytes mno", encoding='UTF-8')}, | |||
| {"file_name": "003.jpg", "label": 61, | |||
| "image1": bytes("image3 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image3 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image3 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image3 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image3 bytes mno", encoding='UTF-8')}, | |||
| {"file_name": "004.jpg", "label": 29, | |||
| "image1": bytes("image4 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image4 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image4 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image4 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image4 bytes mno", encoding='UTF-8')}, | |||
| {"file_name": "005.jpg", "label": 78, | |||
| "image1": bytes("image5 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image5 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image5 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image5 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image5 bytes mno", encoding='UTF-8')}, | |||
| {"file_name": "006.jpg", "label": 37, | |||
| "image1": bytes("image6 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image6 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image6 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image6 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image6 bytes mno", encoding='UTF-8')} | |||
| ] | |||
| writer = FileWriter(mindrecord_file_name) | |||
| schema = {"file_name": {"type": "string"}, | |||
| "image1": {"type": "bytes"}, | |||
| "image2": {"type": "bytes"}, | |||
| "image3": {"type": "bytes"}, | |||
| "label": {"type": "int32"}, | |||
| "image4": {"type": "bytes"}, | |||
| "image5": {"type": "bytes"}} | |||
| writer.add_schema(schema, "data is so cool") | |||
| writer.write_raw_data(data) | |||
| writer.commit() | |||
| # change data value to list | |||
| data_value_to_list = [] | |||
| for item in data: | |||
| new_data = {} | |||
| new_data['file_name'] = np.asarray(list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8) | |||
| new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) | |||
| new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) | |||
| new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) | |||
| new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8) | |||
| new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8) | |||
| new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8) | |||
| data_value_to_list.append(new_data) | |||
| num_readers = 2 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 7 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| num_readers = 2 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["image1", "image2", "image5"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 3 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| num_readers = 2 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["image2", "image4"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 2 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| num_readers = 2 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["image5", "image2"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 2 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| num_readers = 2 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["image5", "image2", "label"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 3 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| num_readers = 2 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["image4", "image5", "image2", "image3", "file_name"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 5 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| os.remove("{}".format(mindrecord_file_name)) | |||
| os.remove("{}.db".format(mindrecord_file_name)) | |||
| def test_write_with_multi_array_and_MindDataset(): | |||
| mindrecord_file_name = "test.mindrecord" | |||
| data = [{"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "source_eos_ids": np.array([13, 14, 15, 16, 17, 18], dtype=np.int64), | |||
| "source_eos_mask": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), | |||
| "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)}, | |||
| {"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "source_eos_ids": np.array([113, 14, 15, 16, 17, 18], dtype=np.int64), | |||
| "source_eos_mask": np.array([119, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), | |||
| "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)}, | |||
| {"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "source_eos_ids": np.array([213, 14, 15, 16, 17, 18], dtype=np.int64), | |||
| "source_eos_mask": np.array([219, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), | |||
| "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)}, | |||
| {"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "source_eos_ids": np.array([313, 14, 15, 16, 17, 18], dtype=np.int64), | |||
| "source_eos_mask": np.array([319, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), | |||
| "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)}, | |||
| {"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "source_eos_ids": np.array([413, 14, 15, 16, 17, 18], dtype=np.int64), | |||
| "source_eos_mask": np.array([419, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), | |||
| "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)}, | |||
| {"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "source_eos_ids": np.array([513, 14, 15, 16, 17, 18], dtype=np.int64), | |||
| "source_eos_mask": np.array([519, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), | |||
| "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)} | |||
| ] | |||
| writer = FileWriter(mindrecord_file_name) | |||
| schema = {"source_sos_ids": {"type": "int64", "shape": [-1]}, | |||
| "source_sos_mask": {"type": "int64", "shape": [-1]}, | |||
| "source_eos_ids": {"type": "int64", "shape": [-1]}, | |||
| "source_eos_mask": {"type": "int64", "shape": [-1]}, | |||
| "target_sos_ids": {"type": "int64", "shape": [-1]}, | |||
| "target_sos_mask": {"type": "int64", "shape": [-1]}, | |||
| "target_eos_ids": {"type": "int64", "shape": [-1]}, | |||
| "target_eos_mask": {"type": "int64", "shape": [-1]}} | |||
| writer.add_schema(schema, "data is so cool") | |||
| writer.write_raw_data(data) | |||
| writer.commit() | |||
| # change data value to list - do none | |||
| data_value_to_list = [] | |||
| for item in data: | |||
| new_data = {} | |||
| new_data['source_sos_ids'] = item["source_sos_ids"] | |||
| new_data['source_sos_mask'] = item["source_sos_mask"] | |||
| new_data['source_eos_ids'] = item["source_eos_ids"] | |||
| new_data['source_eos_mask'] = item["source_eos_mask"] | |||
| new_data['target_sos_ids'] = item["target_sos_ids"] | |||
| new_data['target_sos_mask'] = item["target_sos_mask"] | |||
| new_data['target_eos_ids'] = item["target_eos_ids"] | |||
| new_data['target_eos_mask'] = item["target_eos_mask"] | |||
| data_value_to_list.append(new_data) | |||
| num_readers = 2 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 8 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| num_readers = 2 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["source_eos_ids", "source_eos_mask", | |||
| "target_sos_ids", "target_sos_mask", | |||
| "target_eos_ids", "target_eos_mask"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 6 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| num_readers = 2 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["source_sos_ids", | |||
| "target_sos_ids", | |||
| "target_eos_mask"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 3 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| num_readers = 2 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["target_eos_mask", | |||
| "source_eos_mask", | |||
| "source_sos_mask"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 3 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| num_readers = 2 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["target_eos_ids"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 1 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| num_readers = 1 | |||
| data_set = ds.MindDataset(dataset_file=mindrecord_file_name, | |||
| columns_list=["target_eos_mask", "target_eos_ids", | |||
| "target_sos_mask", "target_sos_ids", | |||
| "source_eos_mask", "source_eos_ids", | |||
| "source_sos_mask", "source_sos_ids"], | |||
| num_parallel_workers=num_readers, | |||
| shuffle=False) | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| assert len(item) == 8 | |||
| for field in item: | |||
| if isinstance(item[field], np.ndarray): | |||
| assert (item[field] == data_value_to_list[num_iter][field]).all() | |||
| else: | |||
| assert item[field] == data_value_to_list[num_iter][field] | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| os.remove("{}".format(mindrecord_file_name)) | |||
| os.remove("{}.db".format(mindrecord_file_name)) | |||
| @@ -448,3 +448,456 @@ def test_cv_file_writer_no_raw(): | |||
| reader.close() | |||
| os.remove(NLP_FILE_NAME) | |||
| os.remove("{}.db".format(NLP_FILE_NAME)) | |||
| def test_write_read_process_with_multi_bytes(): | |||
| mindrecord_file_name = "test.mindrecord" | |||
| data = [{"file_name": "001.jpg", "label": 43, | |||
| "image1": bytes("image1 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image1 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image1 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image1 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image1 bytes mno", encoding='UTF-8')}, | |||
| {"file_name": "002.jpg", "label": 91, | |||
| "image1": bytes("image2 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image2 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image2 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image2 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image2 bytes mno", encoding='UTF-8')}, | |||
| {"file_name": "003.jpg", "label": 61, | |||
| "image1": bytes("image3 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image3 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image3 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image3 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image3 bytes mno", encoding='UTF-8')}, | |||
| {"file_name": "004.jpg", "label": 29, | |||
| "image1": bytes("image4 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image4 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image4 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image4 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image4 bytes mno", encoding='UTF-8')}, | |||
| {"file_name": "005.jpg", "label": 78, | |||
| "image1": bytes("image5 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image5 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image5 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image5 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image5 bytes mno", encoding='UTF-8')}, | |||
| {"file_name": "006.jpg", "label": 37, | |||
| "image1": bytes("image6 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image6 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image6 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image6 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image6 bytes mno", encoding='UTF-8')} | |||
| ] | |||
| writer = FileWriter(mindrecord_file_name) | |||
| schema = {"file_name": {"type": "string"}, | |||
| "image1": {"type": "bytes"}, | |||
| "image2": {"type": "bytes"}, | |||
| "image3": {"type": "bytes"}, | |||
| "label": {"type": "int32"}, | |||
| "image4": {"type": "bytes"}, | |||
| "image5": {"type": "bytes"}} | |||
| writer.add_schema(schema, "data is so cool") | |||
| writer.write_raw_data(data) | |||
| writer.commit() | |||
| reader = FileReader(mindrecord_file_name) | |||
| count = 0 | |||
| for index, x in enumerate(reader.get_next()): | |||
| assert len(x) == 7 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader.close() | |||
| reader2 = FileReader(file_name=mindrecord_file_name, columns=["image1", "image2", "image5"]) | |||
| count = 0 | |||
| for index, x in enumerate(reader2.get_next()): | |||
| assert len(x) == 3 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader2.close() | |||
| reader3 = FileReader(file_name=mindrecord_file_name, columns=["image2", "image4"]) | |||
| count = 0 | |||
| for index, x in enumerate(reader3.get_next()): | |||
| assert len(x) == 2 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader3.close() | |||
| reader4 = FileReader(file_name=mindrecord_file_name, columns=["image5", "image2"]) | |||
| count = 0 | |||
| for index, x in enumerate(reader4.get_next()): | |||
| assert len(x) == 2 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader4.close() | |||
| reader5 = FileReader(file_name=mindrecord_file_name, columns=["image5", "image2", "label"]) | |||
| count = 0 | |||
| for index, x in enumerate(reader5.get_next()): | |||
| assert len(x) == 3 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader5.close() | |||
| os.remove("{}".format(mindrecord_file_name)) | |||
| os.remove("{}.db".format(mindrecord_file_name)) | |||
| def test_write_read_process_with_multi_array(): | |||
| mindrecord_file_name = "test.mindrecord" | |||
| data = [{"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "source_eos_ids": np.array([13, 14, 15, 16, 17, 18], dtype=np.int64), | |||
| "source_eos_mask": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), | |||
| "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)}, | |||
| {"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "source_eos_ids": np.array([113, 14, 15, 16, 17, 18], dtype=np.int64), | |||
| "source_eos_mask": np.array([119, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), | |||
| "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)}, | |||
| {"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "source_eos_ids": np.array([213, 14, 15, 16, 17, 18], dtype=np.int64), | |||
| "source_eos_mask": np.array([219, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), | |||
| "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)}, | |||
| {"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "source_eos_ids": np.array([313, 14, 15, 16, 17, 18], dtype=np.int64), | |||
| "source_eos_mask": np.array([319, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), | |||
| "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)}, | |||
| {"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "source_eos_ids": np.array([413, 14, 15, 16, 17, 18], dtype=np.int64), | |||
| "source_eos_mask": np.array([419, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), | |||
| "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)}, | |||
| {"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "source_eos_ids": np.array([513, 14, 15, 16, 17, 18], dtype=np.int64), | |||
| "source_eos_mask": np.array([519, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), | |||
| "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)} | |||
| ] | |||
| writer = FileWriter(mindrecord_file_name) | |||
| schema = {"source_sos_ids": {"type": "int64", "shape": [-1]}, | |||
| "source_sos_mask": {"type": "int64", "shape": [-1]}, | |||
| "source_eos_ids": {"type": "int64", "shape": [-1]}, | |||
| "source_eos_mask": {"type": "int64", "shape": [-1]}, | |||
| "target_sos_ids": {"type": "int64", "shape": [-1]}, | |||
| "target_sos_mask": {"type": "int64", "shape": [-1]}, | |||
| "target_eos_ids": {"type": "int64", "shape": [-1]}, | |||
| "target_eos_mask": {"type": "int64", "shape": [-1]}} | |||
| writer.add_schema(schema, "data is so cool") | |||
| writer.write_raw_data(data) | |||
| writer.commit() | |||
| reader = FileReader(mindrecord_file_name) | |||
| count = 0 | |||
| for index, x in enumerate(reader.get_next()): | |||
| assert len(x) == 8 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader.close() | |||
| reader = FileReader(file_name=mindrecord_file_name, columns=["source_eos_ids", "source_eos_mask", | |||
| "target_sos_ids", "target_sos_mask", | |||
| "target_eos_ids", "target_eos_mask"]) | |||
| count = 0 | |||
| for index, x in enumerate(reader.get_next()): | |||
| assert len(x) == 6 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader.close() | |||
| reader = FileReader(file_name=mindrecord_file_name, columns=["source_sos_ids", | |||
| "target_sos_ids", | |||
| "target_eos_mask"]) | |||
| count = 0 | |||
| for index, x in enumerate(reader.get_next()): | |||
| assert len(x) == 3 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader.close() | |||
| reader = FileReader(file_name=mindrecord_file_name, columns=["target_eos_mask", | |||
| "source_eos_mask", | |||
| "source_sos_mask"]) | |||
| count = 0 | |||
| for index, x in enumerate(reader.get_next()): | |||
| assert len(x) == 3 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader.close() | |||
| reader = FileReader(file_name=mindrecord_file_name, columns=["target_eos_ids"]) | |||
| count = 0 | |||
| for index, x in enumerate(reader.get_next()): | |||
| assert len(x) == 1 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader.close() | |||
| os.remove("{}".format(mindrecord_file_name)) | |||
| os.remove("{}.db".format(mindrecord_file_name)) | |||
| def test_write_read_process_with_multi_bytes_and_array(): | |||
| mindrecord_file_name = "test.mindrecord" | |||
| data = [{"file_name": "001.jpg", "label": 4, | |||
| "image1": bytes("image1 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image1 bytes def", encoding='UTF-8'), | |||
| "source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "image3": bytes("image1 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image1 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image1 bytes mno", encoding='UTF-8'), | |||
| "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)}, | |||
| {"file_name": "002.jpg", "label": 5, | |||
| "image1": bytes("image2 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image2 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image2 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image2 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image2 bytes mno", encoding='UTF-8'), | |||
| "source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)}, | |||
| {"file_name": "003.jpg", "label": 6, | |||
| "source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "image1": bytes("image3 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image3 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image3 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image3 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image3 bytes mno", encoding='UTF-8'), | |||
| "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)}, | |||
| {"file_name": "004.jpg", "label": 7, | |||
| "source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "image1": bytes("image4 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image4 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image4 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image4 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image4 bytes mno", encoding='UTF-8'), | |||
| "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)}, | |||
| {"file_name": "005.jpg", "label": 8, | |||
| "source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64), | |||
| "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "image1": bytes("image5 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image5 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image5 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image5 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image5 bytes mno", encoding='UTF-8'), | |||
| "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)}, | |||
| {"file_name": "006.jpg", "label": 9, | |||
| "source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64), | |||
| "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64), | |||
| "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64), | |||
| "image1": bytes("image6 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image6 bytes def", encoding='UTF-8'), | |||
| "image3": bytes("image6 bytes ghi", encoding='UTF-8'), | |||
| "image4": bytes("image6 bytes jkl", encoding='UTF-8'), | |||
| "image5": bytes("image6 bytes mno", encoding='UTF-8'), | |||
| "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64), | |||
| "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), | |||
| "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)} | |||
| ] | |||
| writer = FileWriter(mindrecord_file_name) | |||
| schema = {"file_name": {"type": "string"}, | |||
| "image1": {"type": "bytes"}, | |||
| "image2": {"type": "bytes"}, | |||
| "source_sos_ids": {"type": "int64", "shape": [-1]}, | |||
| "source_sos_mask": {"type": "int64", "shape": [-1]}, | |||
| "image3": {"type": "bytes"}, | |||
| "image4": {"type": "bytes"}, | |||
| "image5": {"type": "bytes"}, | |||
| "target_sos_ids": {"type": "int64", "shape": [-1]}, | |||
| "target_sos_mask": {"type": "int64", "shape": [-1]}, | |||
| "target_eos_ids": {"type": "int64", "shape": [-1]}, | |||
| "target_eos_mask": {"type": "int64", "shape": [-1]}, | |||
| "label": {"type": "int32"}} | |||
| writer.add_schema(schema, "data is so cool") | |||
| writer.write_raw_data(data) | |||
| writer.commit() | |||
| reader = FileReader(mindrecord_file_name) | |||
| count = 0 | |||
| for index, x in enumerate(reader.get_next()): | |||
| assert len(x) == 13 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader.close() | |||
| reader = FileReader(file_name=mindrecord_file_name, columns=["source_sos_ids", "source_sos_mask", | |||
| "target_sos_ids"]) | |||
| count = 0 | |||
| for index, x in enumerate(reader.get_next()): | |||
| assert len(x) == 3 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader.close() | |||
| reader = FileReader(file_name=mindrecord_file_name, columns=["image2", "source_sos_mask", | |||
| "image3", "target_sos_ids"]) | |||
| count = 0 | |||
| for index, x in enumerate(reader.get_next()): | |||
| assert len(x) == 4 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader.close() | |||
| reader = FileReader(file_name=mindrecord_file_name, columns=["target_sos_ids", "image4", | |||
| "source_sos_ids"]) | |||
| count = 0 | |||
| for index, x in enumerate(reader.get_next()): | |||
| assert len(x) == 3 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader.close() | |||
| reader = FileReader(file_name=mindrecord_file_name, columns=["target_sos_ids", "image5", | |||
| "image4", "image3", "source_sos_ids"]) | |||
| count = 0 | |||
| for index, x in enumerate(reader.get_next()): | |||
| assert len(x) == 5 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader.close() | |||
| reader = FileReader(file_name=mindrecord_file_name, columns=["target_eos_mask", "image5", "image2", | |||
| "source_sos_mask", "label"]) | |||
| count = 0 | |||
| for index, x in enumerate(reader.get_next()): | |||
| assert len(x) == 5 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader.close() | |||
| os.remove("{}".format(mindrecord_file_name)) | |||
| os.remove("{}.db".format(mindrecord_file_name)) | |||