You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_to_number_op.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.common.dtype as mstype
  18. import mindspore.dataset as ds
  19. import mindspore.dataset.text as text
  20. np_integral_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16,
  21. np.uint32, np.uint64]
  22. ms_integral_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8,
  23. mstype.uint16, mstype.uint32, mstype.uint64]
  24. np_non_integral_types = [np.float16, np.float32, np.float64]
  25. ms_non_integral_types = [mstype.float16, mstype.float32, mstype.float64]
  26. def string_dataset_generator(strings):
  27. for string in strings:
  28. yield (np.array(string, dtype='S'),)
  29. def test_to_number_eager():
  30. """
  31. Test ToNumber op is callable
  32. """
  33. input_strings = [["1", "2", "3"], ["4", "5", "6"]]
  34. op = text.ToNumber(mstype.int8)
  35. # test input_strings as one 2D tensor
  36. result1 = op(input_strings) # np array: [[1 2 3] [4 5 6]]
  37. assert np.array_equal(result1, np.array([[1, 2, 3], [4, 5, 6]], dtype='i'))
  38. # test input multiple tensors
  39. with pytest.raises(RuntimeError) as info:
  40. # test input_strings as two 1D tensor. It's error because to_number is an OneToOne op
  41. _ = op(*input_strings)
  42. assert "The op is OneToOne, can only accept one tensor as input." in str(info.value)
  43. # test input invalid tensor
  44. invalid_input = [["1", "2", "3"], ["4", "5"]]
  45. with pytest.raises(TypeError) as info:
  46. _ = op(invalid_input)
  47. assert "Invalid user input. Got <class 'list'>: [['1', '2', '3'], ['4', '5']], cannot be converted into tensor" in \
  48. str(info.value)
  49. def test_to_number_typical_case_integral():
  50. input_strings = [["-121", "14"], ["-2219", "7623"], ["-8162536", "162371864"],
  51. ["-1726483716", "98921728421"]]
  52. for ms_type, inputs in zip(ms_integral_types, input_strings):
  53. dataset = ds.GeneratorDataset(string_dataset_generator(inputs), "strings")
  54. dataset = dataset.map(operations=text.ToNumber(ms_type), input_columns=["strings"])
  55. expected_output = [int(string) for string in inputs]
  56. output = []
  57. for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  58. output.append(data["strings"])
  59. assert output == expected_output
  60. def test_to_number_typical_case_non_integral():
  61. input_strings = [["-1.1", "1.4"], ["-2219.321", "7623.453"], ["-816256.234282", "162371864.243243"]]
  62. epsilons = [0.001, 0.001, 0.0001, 0.0001, 0.0000001, 0.0000001]
  63. for ms_type, inputs in zip(ms_non_integral_types, input_strings):
  64. dataset = ds.GeneratorDataset(string_dataset_generator(inputs), "strings")
  65. dataset = dataset.map(operations=text.ToNumber(ms_type), input_columns=["strings"])
  66. expected_output = [float(string) for string in inputs]
  67. output = []
  68. for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  69. output.append(data["strings"])
  70. for expected, actual, epsilon in zip(expected_output, output, epsilons):
  71. assert abs(expected - actual) < epsilon
  72. def out_of_bounds_error_message_check(dataset, np_type, value_to_cast):
  73. type_info = np.iinfo(np_type)
  74. type_max = str(type_info.max)
  75. type_min = str(type_info.min)
  76. type_name = str(np.dtype(np_type))
  77. with pytest.raises(RuntimeError) as info:
  78. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  79. pass
  80. assert "string input " + value_to_cast + " will be out of bounds if cast to " + type_name in str(info.value)
  81. assert "valid range is: [" + type_min + ", " + type_max + "]" in str(info.value)
  82. def test_to_number_out_of_bounds_integral():
  83. for np_type, ms_type in zip(np_integral_types, ms_integral_types):
  84. type_info = np.iinfo(np_type)
  85. input_strings = [str(type_info.max + 10)]
  86. dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
  87. dataset = dataset.map(operations=text.ToNumber(ms_type), input_columns=["strings"])
  88. out_of_bounds_error_message_check(dataset, np_type, input_strings[0])
  89. input_strings = [str(type_info.min - 10)]
  90. dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
  91. dataset = dataset.map(operations=text.ToNumber(ms_type), input_columns=["strings"])
  92. out_of_bounds_error_message_check(dataset, np_type, input_strings[0])
  93. def test_to_number_out_of_bounds_non_integral():
  94. above_range = [str(np.finfo(np.float16).max * 10), str(np.finfo(np.float32).max * 10), "1.8e+308"]
  95. input_strings = [above_range[0]]
  96. dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
  97. dataset = dataset.map(operations=text.ToNumber(ms_non_integral_types[0]), input_columns=["strings"])
  98. with pytest.raises(RuntimeError) as info:
  99. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  100. pass
  101. assert "outside of valid float16 range" in str(info.value)
  102. input_strings = [above_range[1]]
  103. dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
  104. dataset = dataset.map(operations=text.ToNumber(ms_non_integral_types[1]), input_columns=["strings"])
  105. with pytest.raises(RuntimeError) as info:
  106. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  107. pass
  108. assert "string input " + input_strings[0] + " will be out of bounds if cast to float32" in str(info.value)
  109. input_strings = [above_range[2]]
  110. dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
  111. dataset = dataset.map(operations=text.ToNumber(ms_non_integral_types[2]), input_columns=["strings"])
  112. with pytest.raises(RuntimeError) as info:
  113. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  114. pass
  115. assert "string input " + input_strings[0] + " will be out of bounds if cast to float64" in str(info.value)
  116. below_range = [str(np.finfo(np.float16).min * 10), str(np.finfo(np.float32).min * 10), "-1.8e+308"]
  117. input_strings = [below_range[0]]
  118. dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
  119. dataset = dataset.map(operations=text.ToNumber(ms_non_integral_types[0]), input_columns=["strings"])
  120. with pytest.raises(RuntimeError) as info:
  121. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  122. pass
  123. assert "outside of valid float16 range" in str(info.value)
  124. input_strings = [below_range[1]]
  125. dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
  126. dataset = dataset.map(operations=text.ToNumber(ms_non_integral_types[1]), input_columns=["strings"])
  127. with pytest.raises(RuntimeError) as info:
  128. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  129. pass
  130. assert "string input " + input_strings[0] + " will be out of bounds if cast to float32" in str(info.value)
  131. input_strings = [below_range[2]]
  132. dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
  133. dataset = dataset.map(operations=text.ToNumber(ms_non_integral_types[2]), input_columns=["strings"])
  134. with pytest.raises(RuntimeError) as info:
  135. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  136. pass
  137. assert "string input " + input_strings[0] + " will be out of bounds if cast to float64" in str(info.value)
  138. def test_to_number_boundaries_integral():
  139. for np_type, ms_type in zip(np_integral_types, ms_integral_types):
  140. type_info = np.iinfo(np_type)
  141. input_strings = [str(type_info.max)]
  142. dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
  143. dataset = dataset.map(operations=text.ToNumber(ms_type), input_columns=["strings"])
  144. for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  145. assert data["strings"] == int(input_strings[0])
  146. input_strings = [str(type_info.min)]
  147. dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
  148. dataset = dataset.map(operations=text.ToNumber(ms_type), input_columns=["strings"])
  149. for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  150. assert data["strings"] == int(input_strings[0])
  151. input_strings = [str(0)]
  152. dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
  153. dataset = dataset.map(operations=text.ToNumber(ms_type), input_columns=["strings"])
  154. for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  155. assert data["strings"] == int(input_strings[0])
  156. def test_to_number_invalid_input():
  157. input_strings = ["a8fa9ds8fa"]
  158. dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
  159. dataset = dataset.map(operations=text.ToNumber(mstype.int32), input_columns=["strings"])
  160. with pytest.raises(RuntimeError) as info:
  161. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  162. pass
  163. assert "it is invalid to convert \"" + input_strings[0] + "\" to a number" in str(info.value)
  164. def test_to_number_invalid_type():
  165. with pytest.raises(TypeError) as info:
  166. dataset = ds.GeneratorDataset(string_dataset_generator(["a8fa9ds8fa"]), "strings")
  167. dataset = dataset.map(operations=text.ToNumber(mstype.bool_), input_columns=["strings"])
  168. assert "data_type: Bool is not numeric data type" in str(info.value)
  169. if __name__ == '__main__':
  170. test_to_number_eager()
  171. test_to_number_typical_case_integral()
  172. test_to_number_typical_case_non_integral()
  173. test_to_number_boundaries_integral()
  174. test_to_number_out_of_bounds_integral()
  175. test_to_number_out_of_bounds_non_integral()
  176. test_to_number_invalid_input()
  177. test_to_number_invalid_type()