|
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- import mindspore._c_dataengine as cde
- import numpy as np
- import pytest
-
- import mindspore.dataset as ds
-
-
def test_basic():
    """Check that a 2-D byte-string array round-trips through ``cde.Tensor``.

    A NumPy array with dtype='S' is wrapped in ``cde.Tensor`` and converted
    back with ``as_array()``; the result must equal the input array.

    Raises:
        AssertionError: If the array read back differs from the input.
    """
    x = np.array([["ab", "cde", "121"], ["x", "km", "789"]], dtype='S')
    n = cde.Tensor(x)
    arr = n.as_array()
    # assert_array_equal raises on mismatch and returns None, so it is called
    # directly rather than wrapped in an `assert` statement.
    np.testing.assert_array_equal(x, arr)
-
-
def compare(strings):
    """Feed ``strings`` through a GeneratorDataset and check they come back unchanged.

    Args:
        strings: A (possibly nested) list of str; it is converted to a NumPy
            byte-string array (dtype='S') before being yielded to the dataset.

    Raises:
        AssertionError: If a row read back differs from the input array, or
            if the dataset produced no rows at all.
    """
    expected = np.array(strings, dtype='S')

    def gen():
        # One row with one column: the row is a 1-tuple holding the array.
        yield (expected,)

    data = ds.GeneratorDataset(gen, column_names=["col"])

    rows_checked = 0
    for row in data:
        np.testing.assert_array_equal(row[0], expected)
        rows_checked += 1

    # Without this guard an empty dataset would let the check pass vacuously.
    assert rows_checked > 0, "dataset produced no rows"
-
-
def test_generator():
    """Run ``compare`` on a single string, a 1-D list and a 2-D list of strings."""
    cases = (
        ["ab"],
        ["ab", "cde", "121"],
        [["ab", "cde", "121"], ["x", "km", "789"]],
    )
    for strings in cases:
        compare(strings)
-
-
def test_batching_strings():
    """Check that batching a string column raises the expected RuntimeError.

    A one-row generator dataset holding a byte-string array is batched with
    ``batch(10)``; iterating it must raise ``RuntimeError`` whose message
    contains "[Batch ERROR] Batch does not support".
    """
    def gen():
        # One row with one column: the row is a 1-tuple holding the array.
        yield (np.array(["ab", "cde", "121"], dtype='S'),)

    data = ds.GeneratorDataset(gen, column_names=["col"]).batch(10)

    with pytest.raises(RuntimeError) as info:
        for _ in data:
            pass
    # The check must sit outside the `with` block (code after the raising
    # statement inside it never runs) and must inspect the exception itself:
    # `info` is pytest's ExceptionInfo wrapper, whose str() form is not the
    # exception message.
    assert "[Batch ERROR] Batch does not support" in str(info.value)
-
-
if __name__ == '__main__':
    # Allow running the tests directly as a script, in a fixed order.
    for run_test in (test_generator, test_basic, test_batching_strings):
        run_test()
|