You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_image_summary.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. @File : test_image_summary.py
  17. @Author:
  18. @Date : 2019-07-4
  19. @Desc : test summary function
  20. """
  21. import os
  22. import logging
  23. import numpy as np
  24. import mindspore.nn as nn
  25. from mindspore.train.summary.summary_record import SummaryRecord, \
  26. _cache_summary_tensor_data
  27. from mindspore import Tensor
  28. from mindspore.nn.optim import Momentum
  29. from mindspore import Model, context
  30. from mindspore.train.callback import SummaryStep
  31. from .....dataset_mock import MindData
# Directory where SummaryRecord writes its event files, rooted at the CWD.
CUR_DIR = os.getcwd()
SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/"
# Test-local logger; ERROR level keeps the per-test debug traces quiet by default.
log = logging.getLogger("test")
log.setLevel(level=logging.ERROR)
  36. def make_image_tensor(shape, dtype=float):
  37. """ make_image_tensor """
  38. # pylint: disable=unused-argument
  39. numel = np.prod(shape)
  40. x = (np.arange(numel, dtype=float)).reshape(shape)
  41. return x
  42. def get_test_data(step):
  43. """ get_test_data """
  44. test_data_list = []
  45. tag1 = "x1[:Image]"
  46. tag2 = "x2[:Image]"
  47. np1 = make_image_tensor([2, 3, 8, 8])
  48. np2 = make_image_tensor([step, 3, 8, 8])
  49. dict1 = {}
  50. dict1["name"] = tag1
  51. dict1["data"] = Tensor(np1)
  52. dict2 = {}
  53. dict2["name"] = tag2
  54. dict2["data"] = Tensor(np2)
  55. test_data_list.append(dict1)
  56. test_data_list.append(dict2)
  57. return test_data_list
  58. # Test: call method on parse graph code
  59. def test_image_summary_sample():
  60. """ test_image_summary_sample """
  61. log.debug("begin test_image_summary_sample")
  62. # step 0: create the thread
  63. test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE")
  64. # step 1: create the test data for summary
  65. # step 2: create the Event
  66. for i in range(1, 5):
  67. test_data = get_test_data(i)
  68. _cache_summary_tensor_data(test_data)
  69. test_writer.record(i)
  70. test_writer.flush()
  71. # step 3: send the event to mq
  72. # step 4: accept the event and write the file
  73. test_writer.close()
  74. log.debug("finished test_image_summary_sample")
  75. class Net(nn.Cell):
  76. """ Net definition """
  77. def __init__(self):
  78. super(Net, self).__init__()
  79. self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal',
  80. pad_mode='valid')
  81. self.bn = nn.BatchNorm2d(64)
  82. self.relu = nn.ReLU()
  83. self.flatten = nn.Flatten()
  84. self.fc = nn.Dense(64 * 222 * 222, 3) # padding=0
  85. def construct(self, x):
  86. x = self.conv(x)
  87. x = self.bn(x)
  88. x = self.relu(x)
  89. x = self.flatten(x)
  90. out = self.fc(x)
  91. return out
  92. class LossNet(nn.Cell):
  93. """ LossNet definition """
  94. def __init__(self):
  95. super(LossNet, self).__init__()
  96. self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal',
  97. pad_mode='valid')
  98. self.bn = nn.BatchNorm2d(64)
  99. self.relu = nn.ReLU()
  100. self.flatten = nn.Flatten()
  101. self.fc = nn.Dense(64 * 222 * 222, 3) # padding=0
  102. self.loss = nn.SoftmaxCrossEntropyWithLogits()
  103. def construct(self, x, y):
  104. x = self.conv(x)
  105. x = self.bn(x)
  106. x = self.relu(x)
  107. x = self.flatten(x)
  108. x = self.fc(x)
  109. out = self.loss(x, y)
  110. return out
  111. def get_model():
  112. """ get_model """
  113. net = Net()
  114. loss = nn.SoftmaxCrossEntropyWithLogits()
  115. optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  116. context.set_context(mode=context.GRAPH_MODE)
  117. model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
  118. return model
  119. def get_dataset():
  120. """ get_dataset """
  121. dataset_types = (np.float32, np.float32)
  122. dataset_shapes = ((2, 3, 224, 224), (2, 3))
  123. dataset = MindData(size=2, batch_size=2,
  124. np_types=dataset_types,
  125. output_shapes=dataset_shapes,
  126. input_indexs=(0, 1))
  127. return dataset
  128. class ImageSummaryCallback:
  129. def __init__(self, summaryRecord):
  130. self._summaryRecord = summaryRecord
  131. def record(self, step, train_network=None):
  132. self._summaryRecord.record(step, train_network)
  133. self._summaryRecord.flush()
  134. def test_image_summary_train():
  135. """ test_image_summary_train """
  136. dataset = get_dataset()
  137. log.debug("begin test_image_summary_sample")
  138. # step 0: create the thread
  139. test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE")
  140. # step 1: create the test data for summary
  141. # step 2: create the Event
  142. model = get_model()
  143. fn = ImageSummaryCallback(test_writer)
  144. summary_recode = SummaryStep(fn, 1)
  145. model.train(2, dataset, callbacks=summary_recode)
  146. # step 3: send the event to mq
  147. # step 4: accept the event and write the file
  148. test_writer.close()
  149. log.debug("finished test_image_summary_sample")
  150. def test_image_summary_data():
  151. """ test_image_summary_data """
  152. dataset = get_dataset()
  153. test_data_list = []
  154. i = 1
  155. for next_element in dataset:
  156. tag = "image_" + str(i) + "[:Image]"
  157. dct = {}
  158. dct["name"] = tag
  159. dct["data"] = Tensor(next_element[0])
  160. test_data_list.append(dct)
  161. i += 1
  162. log.debug("begin test_image_summary_sample")
  163. # step 0: create the thread
  164. test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE")
  165. # step 1: create the test data for summary
  166. # step 2: create the Event
  167. _cache_summary_tensor_data(test_data_list)
  168. test_writer.record(1)
  169. test_writer.flush()
  170. # step 3: send the event to mq
  171. # step 4: accept the event and write the file
  172. test_writer.close()
  173. log.debug("finished test_image_summary_sample")