You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_summary.py 8.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test model train """
  16. import os
  17. import re
  18. import tempfile
  19. import shutil
  20. from collections import Counter
  21. import pytest
  22. from mindspore import dataset as ds
  23. from mindspore import nn, Tensor, context
  24. from mindspore.nn.metrics import Loss
  25. from mindspore.nn.optim import Momentum
  26. from mindspore.dataset.transforms import c_transforms as C
  27. from mindspore.dataset.vision import c_transforms as CV
  28. from mindspore.dataset.vision import Inter
  29. from mindspore.common import dtype as mstype
  30. from mindspore.ops import operations as P
  31. from mindspore.common.initializer import Normal
  32. from mindspore.train import Model
  33. from mindspore.train.callback import SummaryCollector
  34. from tests.summary_utils import SummaryReader
  35. class LeNet5(nn.Cell):
  36. """
  37. Lenet network
  38. Args:
  39. num_class (int): Number of classes. Default: 10.
  40. num_channel (int): Number of channels. Default: 1.
  41. Returns:
  42. Tensor, output tensor
  43. Examples:
  44. >>> LeNet(num_class=10)
  45. """
  46. def __init__(self, num_class=10, num_channel=1, include_top=True):
  47. super(LeNet5, self).__init__()
  48. self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
  49. self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
  50. self.relu = nn.ReLU()
  51. self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
  52. self.include_top = include_top
  53. if self.include_top:
  54. self.flatten = nn.Flatten()
  55. self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
  56. self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
  57. self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
  58. self.scalar_summary = P.ScalarSummary()
  59. self.image_summary = P.ImageSummary()
  60. self.histogram_summary = P.HistogramSummary()
  61. self.tensor_summary = P.TensorSummary()
  62. self.channel = Tensor(num_channel)
  63. def construct(self, x):
  64. self.image_summary('image', x)
  65. x = self.conv1(x)
  66. self.histogram_summary('histogram', x)
  67. x = self.relu(x)
  68. self.tensor_summary('tensor', x)
  69. x = self.relu(x)
  70. x = self.max_pool2d(x)
  71. self.scalar_summary('scalar', self.channel)
  72. x = self.conv2(x)
  73. x = self.relu(x)
  74. x = self.max_pool2d(x)
  75. if not self.include_top:
  76. return x
  77. x = self.flatten(x)
  78. x = self.relu(self.fc1(x))
  79. x = self.relu(self.fc2(x))
  80. x = self.fc3(x)
  81. return x
  82. def create_dataset(data_path, num_samples=2):
  83. """create dataset for train or test"""
  84. num_parallel_workers = 1
  85. # define dataset
  86. mnist_ds = ds.MnistDataset(data_path, num_samples=num_samples)
  87. resize_height, resize_width = 32, 32
  88. rescale = 1.0 / 255.0
  89. rescale_nml = 1 / 0.3081
  90. shift_nml = -1 * 0.1307 / 0.3081
  91. # define map operations
  92. resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
  93. rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
  94. rescale_op = CV.Rescale(rescale, shift=0.0)
  95. hwc2chw_op = CV.HWC2CHW()
  96. type_cast_op = C.TypeCast(mstype.int32)
  97. # apply map operations on images
  98. mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
  99. mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
  100. mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
  101. mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
  102. mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
  103. # apply DatasetOps
  104. mnist_ds = mnist_ds.shuffle(buffer_size=10000) # 10000 as in LeNet train script
  105. mnist_ds = mnist_ds.batch(batch_size=2, drop_remainder=True)
  106. return mnist_ds
  107. class TestSummary:
  108. """Test summary collector the basic function."""
  109. base_summary_dir = ''
  110. mnist_path = '/home/workspace/mindspore_dataset/mnist'
  111. @classmethod
  112. def setup_class(cls):
  113. """Run before test this class."""
  114. device_id = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
  115. context.set_context(mode=context.GRAPH_MODE, device_id=device_id)
  116. cls.base_summary_dir = tempfile.mkdtemp(suffix='summary')
  117. @classmethod
  118. def teardown_class(cls):
  119. """Run after test this class."""
  120. if os.path.exists(cls.base_summary_dir):
  121. shutil.rmtree(cls.base_summary_dir)
  122. def _run_network(self, dataset_sink_mode=False, num_samples=2):
  123. lenet = LeNet5()
  124. loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
  125. optim = Momentum(lenet.trainable_params(), learning_rate=0.1, momentum=0.9)
  126. model = Model(lenet, loss_fn=loss, optimizer=optim, metrics={'acc': Loss()})
  127. summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
  128. summary_collector = SummaryCollector(summary_dir=summary_dir, collect_freq=2)
  129. ds_train = create_dataset(os.path.join(self.mnist_path, "train"), num_samples=num_samples)
  130. model.train(1, ds_train, callbacks=[summary_collector], dataset_sink_mode=dataset_sink_mode)
  131. ds_eval = create_dataset(os.path.join(self.mnist_path, "test"))
  132. model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode, callbacks=[summary_collector])
  133. return summary_dir
  134. @pytest.mark.level0
  135. @pytest.mark.platform_x86_ascend_training
  136. @pytest.mark.platform_arm_ascend_training
  137. @pytest.mark.env_onecard
  138. def test_summary_with_sink_mode_false(self):
  139. """Test summary with sink mode false, and num samples is 64."""
  140. summary_dir = self._run_network(num_samples=10)
  141. tag_list = self._list_summary_tags(summary_dir)
  142. expected_tag_set = {'conv1.weight/auto', 'conv2.weight/auto', 'fc1.weight/auto', 'fc1.bias/auto',
  143. 'fc2.weight/auto', 'input_data/auto', 'loss/auto',
  144. 'histogram', 'image', 'scalar', 'tensor'}
  145. assert set(expected_tag_set) == set(tag_list)
  146. # num samples is 10, batch size is 2, so step is 5, collect freq is 2,
  147. # SummaryCollector will collect the first step and 2th, 4th step
  148. tag_count = 3
  149. for value in Counter(tag_list).values():
  150. assert value == tag_count
  151. @pytest.mark.level0
  152. @pytest.mark.platform_x86_ascend_training
  153. @pytest.mark.platform_arm_ascend_training
  154. @pytest.mark.env_onecard
  155. def test_summary_with_sink_mode_true(self):
  156. """Test summary with sink mode true, and num samples is 64."""
  157. summary_dir = self._run_network(dataset_sink_mode=True, num_samples=10)
  158. tag_list = self._list_summary_tags(summary_dir)
  159. # There will not record input data when dataset sink mode is True
  160. expected_tags = {'conv1.weight/auto', 'conv2.weight/auto', 'fc1.weight/auto', 'fc1.bias/auto',
  161. 'fc2.weight/auto', 'loss/auto', 'histogram', 'image', 'scalar', 'tensor'}
  162. assert set(expected_tags) == set(tag_list)
  163. tag_count = 1
  164. for value in Counter(tag_list).values():
  165. assert value == tag_count
  166. @staticmethod
  167. def _list_summary_tags(summary_dir):
  168. summary_file_path = ''
  169. for file in os.listdir(summary_dir):
  170. if re.search("_MS", file):
  171. summary_file_path = os.path.join(summary_dir, file)
  172. break
  173. assert summary_file_path
  174. tags = list()
  175. with SummaryReader(summary_file_path) as summary_reader:
  176. while True:
  177. summary_event = summary_reader.read_event()
  178. if not summary_event:
  179. break
  180. for value in summary_event.summary.value:
  181. tags.append(value.tag)
  182. return tags