
test_summary_collector.py

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test SummaryCollector."""
import os
import re
import shutil
import tempfile
from collections import Counter

import pytest

from mindspore import nn, Tensor, context
from mindspore.common.initializer import Normal
from mindspore.nn.metrics import Loss
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.train.callback import SummaryCollector

from tests.st.summary.dataset import create_mnist_dataset
from tests.summary_utils import SummaryReader

class LeNet5(nn.Cell):
    """
    LeNet-5 network.

    Args:
        num_class (int): Number of classes. Default: 10.
        num_channel (int): Number of channels. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> LeNet5(num_class=10)
    """

    def __init__(self, num_class=10, num_channel=1, include_top=True):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.include_top = include_top
        if self.include_top:
            self.flatten = nn.Flatten()
            self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.scalar_summary = P.ScalarSummary()
        self.image_summary = P.ImageSummary()
        self.histogram_summary = P.HistogramSummary()
        self.tensor_summary = P.TensorSummary()
        self.channel = Tensor(num_channel)

    def construct(self, x):
        """construct."""
        self.image_summary('image', x)
        x = self.conv1(x)
        self.histogram_summary('histogram', x)
        x = self.relu(x)
        self.tensor_summary('tensor', x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        self.scalar_summary('scalar', self.channel)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        if not self.include_top:
            return x
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

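# The four P.*Summary operators invoked in construct() produce the raw 'image',
# 'histogram', 'tensor' and 'scalar' tags that the tests below assert on; the
# '*/auto' tags (weights, loss, input data) are collected by SummaryCollector itself.
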
class TestSummary:
    """Test the basic functions of SummaryCollector."""

    base_summary_dir = ''

    @classmethod
    def setup_class(cls):
        """Run before testing this class."""
        device_id = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
        context.set_context(mode=context.GRAPH_MODE, device_id=device_id)
        cls.base_summary_dir = tempfile.mkdtemp(suffix='summary')

    @classmethod
    def teardown_class(cls):
        """Run after testing this class."""
        if os.path.exists(cls.base_summary_dir):
            shutil.rmtree(cls.base_summary_dir)

    def _run_network(self, dataset_sink_mode=False, num_samples=2, **kwargs):
        """Build LeNet5, then train and evaluate it with a SummaryCollector attached."""
        lenet = LeNet5()
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
        optim = Momentum(lenet.trainable_params(), learning_rate=0.1, momentum=0.9)
        model = Model(lenet, loss_fn=loss, optimizer=optim, metrics={'loss': Loss()})
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        summary_collector = SummaryCollector(summary_dir=summary_dir, collect_freq=2, **kwargs)

        ds_train = create_mnist_dataset("train", num_samples=num_samples)
        model.train(1, ds_train, callbacks=[summary_collector], dataset_sink_mode=dataset_sink_mode)

        ds_eval = create_mnist_dataset("test")
        model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode, callbacks=[summary_collector])
        return summary_dir

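    # Each call creates its own sub-directory under base_summary_dir via tempfile.mkdtemp(),
    # so the tag and file assertions in the tests below only ever see events from their own run.
    # collect_freq=2 makes the collector record the first step and then every 2nd step
    # (see the step-count comment in test_summary_with_sink_mode_false).
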
    @pytest.mark.level0
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.env_onecard
    def test_summary_with_sink_mode_false(self):
        """Test summary with sink mode false, and num samples is 10."""
        summary_dir = self._run_network(num_samples=10)

        tag_list = self._list_summary_tags(summary_dir)

        expected_tag_set = {'conv1.weight/auto', 'conv2.weight/auto', 'fc1.weight/auto', 'fc1.bias/auto',
                            'fc2.weight/auto', 'input_data/auto', 'loss/auto',
                            'histogram', 'image', 'scalar', 'tensor'}
        assert set(expected_tag_set) == set(tag_list)

        # num_samples is 10 and the batch size is 2, so there are 5 steps; with collect_freq=2,
        # SummaryCollector collects the 1st, 2nd and 4th steps, i.e. 3 events per tag.
        tag_count = 3
        for value in Counter(tag_list).values():
            assert value == tag_count

    @pytest.mark.level0
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.env_onecard
    def test_summary_with_sink_mode_true(self):
        """Test summary with sink mode true, and num samples is 10."""
        summary_dir = self._run_network(dataset_sink_mode=True, num_samples=10)

        tag_list = self._list_summary_tags(summary_dir)

        # Input data is not recorded when dataset sink mode is True.
        expected_tags = {'conv1.weight/auto', 'conv2.weight/auto', 'fc1.weight/auto', 'fc1.bias/auto',
                         'fc2.weight/auto', 'loss/auto', 'histogram', 'image', 'scalar', 'tensor'}
        assert set(expected_tags) == set(tag_list)

        tag_count = 1
        for value in Counter(tag_list).values():
            assert value == tag_count

    @pytest.mark.level0
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.env_onecard
    def test_summary_collector_user_defined(self):
        """Test SummaryCollector with user-defined lineage data and export options."""
        summary_dir = self._run_network(dataset_sink_mode=True, num_samples=2,
                                        custom_lineage_data={'test': 'self test'},
                                        export_options={'tensor_format': 'npy'})

        tag_list = self._list_summary_tags(summary_dir)
        file_list = self._list_tensor_files(summary_dir)

        # Input data is not recorded when dataset sink mode is True.
        expected_tags = {'conv1.weight/auto', 'conv2.weight/auto', 'fc1.weight/auto', 'fc1.bias/auto',
                         'fc2.weight/auto', 'loss/auto', 'histogram', 'image', 'scalar', 'tensor'}
        assert set(expected_tags) == set(tag_list)

        expected_files = {'tensor_1.npy'}
        assert set(expected_files) == set(file_list)

    @staticmethod
    def _list_summary_tags(summary_dir):
        """list summary tags."""
        summary_file_path = ''
        for file in os.listdir(summary_dir):
            if re.search("_MS", file):
                summary_file_path = os.path.join(summary_dir, file)
                break
        assert summary_file_path

        tags = list()
        with SummaryReader(summary_file_path) as summary_reader:
            while True:
                summary_event = summary_reader.read_event()
                if not summary_event:
                    break
                for value in summary_event.summary.value:
                    tags.append(value.tag)
        return tags

    @staticmethod
    def _list_tensor_files(summary_dir):
        """list exported tensor files."""
        export_file_path = ''
        for file in os.listdir(summary_dir):
            if re.search("export_", file):
                export_file_path = os.path.join(summary_dir, file)
                break
        assert export_file_path

        tensor_file_path = os.path.join(export_file_path, "tensor")
        assert os.path.exists(tensor_file_path)

        tensors = list()
        for file in os.listdir(tensor_file_path):
            tensors.append(file)
        return tensors
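
# A minimal sketch (not executed by the tests above) of how the tensor exported by
# test_summary_collector_user_defined could be inspected by hand. The directory layout
# assumed here (an "export_..." folder containing a "tensor" sub-folder with .npy files)
# is the one _list_tensor_files walks; numpy is assumed to be available.
#
#     import numpy as np
#     export_dir = ...  # the export_... directory created under the summary_dir
#     data = np.load(os.path.join(export_dir, "tensor", "tensor_1.npy"))
#     print(data.shape, data.dtype)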