You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_histogram_summary.py 6.7 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Test histogram summary."""
  16. import logging
  17. import os
  18. import tempfile
  19. import numpy as np
  20. from mindspore.common.tensor import Tensor
  21. from mindspore.train.summary._summary_adapter import _calc_histogram_bins
  22. from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
  23. from tests.summary_utils import SummaryReader
  24. from tests.security_utils import security_off_wrap
  25. CUR_DIR = os.getcwd()
  26. SUMMARY_DIR = os.path.join(CUR_DIR, "/test_temp_summary_event_file/")
  27. LOG = logging.getLogger("test")
  28. LOG.setLevel(level=logging.ERROR)
  29. def _wrap_test_data(input_data: Tensor):
  30. """
  31. Wraps test data to summary format.
  32. Args:
  33. input_data (Tensor): Input data.
  34. Returns:
  35. dict, the wrapped data.
  36. """
  37. return [{
  38. "name": "test_data[:Histogram]",
  39. "data": input_data
  40. }]
  41. @security_off_wrap
  42. def test_histogram_summary():
  43. """Test histogram summary."""
  44. with tempfile.TemporaryDirectory() as tmp_dir:
  45. with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
  46. test_data = _wrap_test_data(Tensor([[1, 2, 3], [4, 5, 6]]))
  47. _cache_summary_tensor_data(test_data)
  48. test_writer.record(step=1)
  49. file_name = os.path.realpath(test_writer.log_dir)
  50. with SummaryReader(file_name) as reader:
  51. event = reader.read_event()
  52. assert event.summary.value[0].histogram.count == 6
  53. @security_off_wrap
  54. def test_histogram_multi_summary():
  55. """Test histogram multiple step."""
  56. with tempfile.TemporaryDirectory() as tmp_dir:
  57. with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
  58. rng = np.random.RandomState(10)
  59. size = 50
  60. num_step = 5
  61. for i in range(num_step):
  62. arr = rng.normal(size=size)
  63. test_data = _wrap_test_data(Tensor(arr))
  64. _cache_summary_tensor_data(test_data)
  65. test_writer.record(step=i)
  66. file_name = os.path.realpath(test_writer.log_dir)
  67. with SummaryReader(file_name) as reader:
  68. for _ in range(num_step):
  69. event = reader.read_event()
  70. assert event.summary.value[0].histogram.count == size
  71. @security_off_wrap
  72. def test_histogram_summary_empty_tensor():
  73. """Test histogram summary, input is an empty tensor."""
  74. with tempfile.TemporaryDirectory() as tmp_dir:
  75. with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
  76. test_data = _wrap_test_data(Tensor([]))
  77. _cache_summary_tensor_data(test_data)
  78. test_writer.record(step=1)
  79. file_name = os.path.realpath(test_writer.log_dir)
  80. with SummaryReader(file_name) as reader:
  81. event = reader.read_event()
  82. assert event.summary.value[0].histogram.count == 0
  83. @security_off_wrap
  84. def test_histogram_summary_same_value():
  85. """Test histogram summary, input is an ones tensor."""
  86. with tempfile.TemporaryDirectory() as tmp_dir:
  87. with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
  88. dim1 = 100
  89. dim2 = 100
  90. test_data = _wrap_test_data(Tensor(np.ones([dim1, dim2])))
  91. _cache_summary_tensor_data(test_data)
  92. test_writer.record(step=1)
  93. file_name = os.path.realpath(test_writer.log_dir)
  94. with SummaryReader(file_name) as reader:
  95. event = reader.read_event()
  96. LOG.debug(event)
  97. assert len(event.summary.value[0].histogram.buckets) == _calc_histogram_bins(dim1 * dim2)
  98. @security_off_wrap
  99. def test_histogram_summary_high_dims():
  100. """Test histogram summary, input is a 4-dimension tensor."""
  101. with tempfile.TemporaryDirectory() as tmp_dir:
  102. with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
  103. dim = 10
  104. rng = np.random.RandomState(0)
  105. tensor_data = rng.normal(size=[dim, dim, dim, dim])
  106. test_data = _wrap_test_data(Tensor(tensor_data))
  107. _cache_summary_tensor_data(test_data)
  108. test_writer.record(step=1)
  109. file_name = os.path.realpath(test_writer.log_dir)
  110. with SummaryReader(file_name) as reader:
  111. event = reader.read_event()
  112. LOG.debug(event)
  113. assert event.summary.value[0].histogram.count == tensor_data.size
  114. @security_off_wrap
  115. def test_histogram_summary_nan_inf():
  116. """Test histogram summary, input tensor has nan."""
  117. with tempfile.TemporaryDirectory() as tmp_dir:
  118. with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
  119. dim1 = 100
  120. dim2 = 100
  121. arr = np.ones([dim1, dim2])
  122. arr[0][0] = np.nan
  123. arr[0][1] = np.inf
  124. arr[0][2] = -np.inf
  125. test_data = _wrap_test_data(Tensor(arr))
  126. _cache_summary_tensor_data(test_data)
  127. test_writer.record(step=1)
  128. file_name = os.path.realpath(test_writer.log_dir)
  129. with SummaryReader(file_name) as reader:
  130. event = reader.read_event()
  131. LOG.debug(event)
  132. assert event.summary.value[0].histogram.nan_count == 1
  133. @security_off_wrap
  134. def test_histogram_summary_all_nan_inf():
  135. """Test histogram summary, input tensor has no valid number."""
  136. with tempfile.TemporaryDirectory() as tmp_dir:
  137. with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
  138. test_data = _wrap_test_data(Tensor(np.array([np.nan, np.nan, np.nan, np.inf, -np.inf])))
  139. _cache_summary_tensor_data(test_data)
  140. test_writer.record(step=1)
  141. file_name = os.path.realpath(test_writer.log_dir)
  142. with SummaryReader(file_name) as reader:
  143. event = reader.read_event()
  144. LOG.debug(event)
  145. histogram = event.summary.value[0].histogram
  146. assert histogram.nan_count == 3
  147. assert histogram.pos_inf_count == 1
  148. assert histogram.neg_inf_count == 1