|
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- import mindspore._c_dataengine as cde
- import numpy as np
- import pytest
-
- from mindspore.dataset.text import to_str, to_bytes
-
- import mindspore.dataset as ds
- import mindspore.common.dtype as mstype
-
-
# pylint: disable=comparison-with-itself
def test_basic():
    """Round-trip a 2-D byte-string numpy array through a cde.Tensor."""
    original = np.array([["ab", "cde", "121"], ["x", "km", "789"]], dtype='S')
    tensor = cde.Tensor(original)
    recovered = tensor.as_array()
    np.testing.assert_array_equal(original, recovered)
-
-
def compare(strings):
    """Feed *strings* (as a byte-string array) through a GeneratorDataset
    and assert every produced row equals the source array."""
    expected = np.array(strings, dtype='S')

    def gen():
        yield expected,

    dataset = ds.GeneratorDataset(gen, column_names=["col"])

    for row in dataset:
        np.testing.assert_array_equal(row[0], expected)
-
-
def test_generator():
    """String arrays of rank 1 and rank 2 pass through GeneratorDataset intact."""
    for case in (["ab"],
                 ["ab", "cde", "121"],
                 [["ab", "cde", "121"], ["x", "km", "789"]]):
        compare(case)
-
-
def test_batching_strings():
    """Batching a string column is unsupported and must raise RuntimeError."""
    def gen():
        yield np.array(["ab", "cde", "121"], dtype='S'),

    dataset = ds.GeneratorDataset(gen, column_names=["col"]).batch(10)

    with pytest.raises(RuntimeError) as info:
        for _ in dataset:
            pass
    assert "[Batch ERROR] Batch does not support" in str(info.value)
-
-
def test_map():
    """A Python map op can decode, whitespace-split, and re-encode a string column."""
    def gen():
        yield np.array(["ab cde 121"], dtype='S'),

    def split(batch):
        # bytes -> str, split on whitespace, back to byte strings
        text = to_str(batch).item()
        return np.array(text.split(), dtype='S')

    dataset = ds.GeneratorDataset(gen, column_names=["col"])
    dataset = dataset.map(input_columns=["col"], operations=split)

    expected = np.array(["ab", "cde", "121"], dtype='S')
    for row in dataset:
        np.testing.assert_array_equal(row[0], expected)
-
-
def test_map2():
    """np.char ufuncs work as map operations on string columns."""
    def gen():
        yield np.array(["ab cde 121"], dtype='S'),

    def to_upper(batch):
        return np.char.upper(batch)

    dataset = ds.GeneratorDataset(gen, column_names=["col"])
    dataset = dataset.map(input_columns=["col"], operations=to_upper)

    expected = np.array(["AB CDE 121"], dtype='S')
    for row in dataset:
        np.testing.assert_array_equal(row[0], expected)
-
-
# Ground-truth rows matching the records stored on disk under
# ../data/dataset/testTextTFRecord/ and ../data/dataset/testTextMindRecord/;
# the read-back tests below compare decoded strings against these arrays.
line = np.array([
    "This is a text file.",
    "Be happy every day.",
    "Good luck to everyone.",
])

words = np.array([
    ["This", "text", "file", "a"],
    ["Be", "happy", "day", "b"],
    ["女", "", "everyone", "c"],
])

chinese = np.array([
    "今天天气太好了我们一起去外面玩吧",
    "男默女泪",
    "江州市长江大桥参加了长江大桥的通车仪式",
])
-
-
def test_tfrecord1():
    """Read string columns from a TFRecord using a schema built in code."""
    schema = ds.Schema()
    schema.add_column("line", "string", [])
    schema.add_column("words", "string", [-1])
    schema.add_column("chinese", "string", [])

    dataset = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord",
                                 shuffle=False, schema=schema)

    for i, row in enumerate(dataset.create_dict_iterator()):
        assert row["line"].shape == line[i].shape
        assert row["words"].shape == words[i].shape
        assert row["chinese"].shape == chinese[i].shape
        np.testing.assert_array_equal(line[i], to_str(row["line"]))
        np.testing.assert_array_equal(words[i], to_str(row["words"]))
        np.testing.assert_array_equal(chinese[i], to_str(row["chinese"]))
-
-
def test_tfrecord2():
    """Read string columns from a TFRecord using a JSON schema file."""
    dataset = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord",
                                 shuffle=False,
                                 schema='../data/dataset/testTextTFRecord/datasetSchema.json')

    for i, row in enumerate(dataset.create_dict_iterator()):
        assert row["line"].shape == line[i].shape
        assert row["words"].shape == words[i].shape
        assert row["chinese"].shape == chinese[i].shape
        np.testing.assert_array_equal(line[i], to_str(row["line"]))
        np.testing.assert_array_equal(words[i], to_str(row["words"]))
        np.testing.assert_array_equal(chinese[i], to_str(row["chinese"]))
-
-
def test_tfrecord3():
    """Read string columns with an mstype-based schema that reshapes 'words' to (-1, 2)."""
    schema = ds.Schema()
    schema.add_column("line", mstype.string, [])
    schema.add_column("words", mstype.string, [-1, 2])
    schema.add_column("chinese", mstype.string, [])

    dataset = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord",
                                 shuffle=False, schema=schema)

    for i, row in enumerate(dataset.create_dict_iterator()):
        # the 4 words per row come back as a 2x2 grid under this schema
        expected_words = words[i].reshape([2, 2])
        assert row["line"].shape == line[i].shape
        assert row["words"].shape == expected_words.shape
        assert row["chinese"].shape == chinese[i].shape
        np.testing.assert_array_equal(line[i], to_str(row["line"]))
        np.testing.assert_array_equal(expected_words, to_str(row["words"]))
        np.testing.assert_array_equal(chinese[i], to_str(row["chinese"]))
-
-
def create_text_mindrecord():
    """One-off helper: build testTextMindRecord/test.mindrecord with string data.

    Not run by the suite; kept so the fixture file can be regenerated.
    """
    from mindspore.mindrecord import FileWriter

    schema = {"english": {"type": "string"},
              "chinese": {"type": "string"}}
    rows = [
        {"english": "This is a text file.",
         "chinese": "今天天气太好了我们一起去外面玩吧"},
        {"english": "Be happy every day.",
         "chinese": "男默女泪"},
        {"english": "Good luck to everyone.",
         "chinese": "江州市长江大桥参加了长江大桥的通车仪式"},
    ]

    writer = FileWriter("test.mindrecord")
    writer.add_schema(schema)
    writer.write_raw_data(rows)
    writer.commit()
-
-
def test_mindrecord():
    """Read back the string columns from the pre-built mindrecord fixture."""
    dataset = ds.MindDataset("../data/dataset/testTextMindRecord/test.mindrecord",
                             shuffle=False)

    for i, row in enumerate(dataset.create_dict_iterator()):
        assert row["english"].shape == line[i].shape
        assert row["chinese"].shape == chinese[i].shape
        np.testing.assert_array_equal(line[i], to_str(row["english"]))
        np.testing.assert_array_equal(chinese[i], to_str(row["chinese"]))
-
-
if __name__ == '__main__':
    # Run all tests directly, in the same order as before.
    for case in (test_generator, test_basic, test_batching_strings,
                 test_map, test_map2, test_tfrecord1, test_tfrecord2,
                 test_tfrecord3, test_mindrecord):
        case()
|