From 847e59bd6c35bd3e28d72f170fc891638ac4d73c Mon Sep 17 00:00:00 2001
From: jonyguo
Date: Tue, 1 Dec 2020 17:30:57 +0800
Subject: [PATCH] fix: ndarray field without type in mindrecord

ShardWriter previously serialised ndarray blob fields with a bare
`v.tobytes()`, i.e. using whatever dtype the caller's array happened to
have. The loop now keeps the field name and casts each ndarray to the
type declared for that field in the schema
(`self._header.schema[field]["type"]`) before converting it to bytes, so
the stored bytes follow the declared type.

A unit test is added that writes a `mask` array created without an
explicit dtype (int64 according to the test's own comment) against a
schema declaring `int32`, reads the record back and checks that every
field compares equal to the input.

---
Reviewer note: the line breaks and Python indentation below were
reconstructed from a whitespace-collapsed copy of this patch (hunk line
counts were used as the check). Verify against the original commit
before applying.

 mindspore/mindrecord/shardwriter.py           |  4 +--
 .../python/mindrecord/test_mindrecord_base.py | 36 +++++++++++++++++++
 2 files changed, 38 insertions(+), 2 deletions(-)

diff --git a/mindspore/mindrecord/shardwriter.py b/mindspore/mindrecord/shardwriter.py
index 49391d13ce..37e453a6a7 100644
--- a/mindspore/mindrecord/shardwriter.py
+++ b/mindspore/mindrecord/shardwriter.py
@@ -196,10 +196,10 @@ class ShardWriter:
         def int_to_bytes(x: int) -> bytes:
             return x.to_bytes(8, 'big')
         merged = bytes()
-        for _, v in blob_data.items():
+        for field, v in blob_data.items():
             # convert ndarray to bytes
             if isinstance(v, np.ndarray):
-                v = v.tobytes()
+                v = v.astype(self._header.schema[field]["type"]).tobytes()
             merged += int_to_bytes(len(v))
             merged += v
         return merged
diff --git a/tests/ut/python/mindrecord/test_mindrecord_base.py b/tests/ut/python/mindrecord/test_mindrecord_base.py
index 3b25cf73a4..844174dc7a 100644
--- a/tests/ut/python/mindrecord/test_mindrecord_base.py
+++ b/tests/ut/python/mindrecord/test_mindrecord_base.py
@@ -964,3 +964,39 @@ def test_write_read_process_with_multi_bytes_and_array():
 
     os.remove("{}".format(mindrecord_file_name))
     os.remove("{}.db".format(mindrecord_file_name))
+
+def test_write_read_process_without_ndarray_type():
+    mindrecord_file_name = "test.mindrecord"
+    # field: mask derivation type is int64, but schema type is int32
+    data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9]),
+             "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
+             "data": bytes("image bytes abc", encoding='UTF-8')}
+           ]
+    writer = FileWriter(mindrecord_file_name)
+    schema = {"file_name": {"type": "string"},
+              "label": {"type": "int32"},
+              "score": {"type": "float64"},
+              "mask": {"type": "int32", "shape": [-1]},
+              "segments": {"type": "float32", "shape": [2, 2]},
+              "data": {"type": "bytes"}}
+    writer.add_schema(schema, "data is so cool")
+    writer.write_raw_data(data)
+    writer.commit()
+
+    reader = FileReader(mindrecord_file_name)
+    count = 0
+    for index, x in enumerate(reader.get_next()):
+        assert len(x) == 6
+        for field in x:
+            if isinstance(x[field], np.ndarray):
+                print("output: {}, input: {}".format(x[field], data[count][field]))
+                assert (x[field] == data[count][field]).all()
+            else:
+                assert x[field] == data[count][field]
+        count = count + 1
+        logger.info("#item{}: {}".format(index, x))
+    assert count == 1
+    reader.close()
+
+    os.remove("{}".format(mindrecord_file_name))
+    os.remove("{}.db".format(mindrecord_file_name))