You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

visualizer.py 7.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. """
  2. Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
  3. Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
  4. """
  5. import os
  6. import ntpath
  7. import time
  8. from . import util
  9. from . import html
  10. import scipy.misc
  11. try:
  12. from StringIO import StringIO # Python 2.7
  13. except ImportError:
  14. from io import BytesIO # Python 3.x
  15. class Visualizer():
  16. def __init__(self, opt):
  17. self.opt = opt
  18. self.tf_log = opt.isTrain and opt.tf_log
  19. self.use_html = opt.isTrain and not opt.no_html
  20. self.win_size = opt.display_winsize
  21. self.name = opt.name
  22. if self.tf_log:
  23. import tensorflow as tf
  24. self.tf = tf
  25. self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')
  26. self.writer = tf.summary.FileWriter(self.log_dir)
  27. if self.use_html:
  28. self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
  29. self.img_dir = os.path.join(self.web_dir, 'images')
  30. print('create web directory %s...' % self.web_dir)
  31. util.mkdirs([self.web_dir, self.img_dir])
  32. if opt.isTrain:
  33. self.log_name = os.path.join(
  34. opt.checkpoints_dir, opt.name, 'loss_log.txt')
  35. with open(self.log_name, "a") as log_file:
  36. now = time.strftime("%c")
  37. log_file.write(
  38. '================ Training Loss (%s) ================\n' % now)
  39. # |visuals|: dictionary of images to display or save
  40. def display_current_results(self, visuals, epoch, step):
  41. # convert tensors to numpy arrays
  42. visuals = self.convert_visuals_to_numpy(visuals)
  43. if self.tf_log: # show images in tensorboard output
  44. img_summaries = []
  45. for label, image_numpy in visuals.items():
  46. # Write the image to a string
  47. try:
  48. s = StringIO()
  49. except:
  50. s = BytesIO()
  51. if len(image_numpy.shape) >= 4:
  52. image_numpy = image_numpy[0]
  53. scipy.misc.toimage(image_numpy).save(s, format="jpeg")
  54. # Create an Image object
  55. img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(
  56. ), height=image_numpy.shape[0], width=image_numpy.shape[1])
  57. # Create a Summary value
  58. img_summaries.append(
  59. self.tf.Summary.Value(tag=label, image=img_sum))
  60. # Create and write Summary
  61. summary = self.tf.Summary(value=img_summaries)
  62. self.writer.add_summary(summary, step)
  63. if self.use_html: # save images to a html file
  64. for label, image_numpy in visuals.items():
  65. if isinstance(image_numpy, list):
  66. for i in range(len(image_numpy)):
  67. img_path = os.path.join(
  68. self.img_dir, 'epoch%.3d_iter%.3d_%s_%d.png' % (epoch, step, label, i))
  69. util.save_image(image_numpy[i], img_path)
  70. else:
  71. img_path = os.path.join(
  72. self.img_dir, 'epoch%.3d_iter%.3d_%s.png' % (epoch, step, label))
  73. if len(image_numpy.shape) >= 4:
  74. image_numpy = image_numpy[0]
  75. util.save_image(image_numpy, img_path)
  76. # update website
  77. webpage = html.HTML(
  78. self.web_dir, 'Experiment name = %s' % self.name, refresh=5)
  79. for n in range(epoch, 0, -1):
  80. webpage.add_header('epoch [%d]' % n)
  81. ims = []
  82. txts = []
  83. links = []
  84. for label, image_numpy in visuals.items():
  85. if isinstance(image_numpy, list):
  86. for i in range(len(image_numpy)):
  87. img_path = 'epoch%.3d_iter%.3d_%s_%d.png' % (
  88. n, step, label, i)
  89. ims.append(img_path)
  90. txts.append(label+str(i))
  91. links.append(img_path)
  92. else:
  93. img_path = 'epoch%.3d_iter%.3d_%s.png' % (
  94. n, step, label)
  95. ims.append(img_path)
  96. txts.append(label)
  97. links.append(img_path)
  98. if len(ims) < 10:
  99. webpage.add_images(ims, txts, links, width=self.win_size)
  100. else:
  101. num = int(round(len(ims)/2.0))
  102. webpage.add_images(
  103. ims[:num], txts[:num], links[:num], width=self.win_size)
  104. webpage.add_images(
  105. ims[num:], txts[num:], links[num:], width=self.win_size)
  106. webpage.save()
  107. # errors: dictionary of error labels and values
  108. def plot_current_errors(self, errors, step):
  109. if self.tf_log:
  110. for tag, value in errors.items():
  111. value = value.mean().float()
  112. summary = self.tf.Summary(
  113. value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
  114. self.writer.add_summary(summary, step)
  115. # errors: same format as |errors| of plotCurrentErrors
  116. def print_current_errors(self, epoch, i, errors, t):
  117. message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
  118. for k, v in errors.items():
  119. # print(v)
  120. # if v != 0:
  121. v = v.mean().float()
  122. message += '%s: %.3f ' % (k, v)
  123. print(message)
  124. with open(self.log_name, "a") as log_file:
  125. log_file.write('%s\n' % message)
  126. def convert_visuals_to_numpy(self, visuals):
  127. for key, t in visuals.items():
  128. tile = self.opt.batchSize > 8
  129. if 'input_label' == key:
  130. t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile)
  131. else:
  132. t = util.tensor2im(t, tile=tile)
  133. visuals[key] = t
  134. return visuals
  135. # save image to the disk
  136. def save_images(self, webpage, visuals, image_path):
  137. visuals = self.convert_visuals_to_numpy(visuals)
  138. image_dir = webpage.get_image_dir()
  139. short_path = ntpath.basename(image_path[0])
  140. name = os.path.splitext(short_path)[0]
  141. webpage.add_header(name)
  142. ims = []
  143. txts = []
  144. links = []
  145. for label, image_numpy in visuals.items():
  146. if label == 'input_label':
  147. image_name = os.path.join(label, '%s.png' % (name))
  148. else:
  149. image_name = os.path.join(label, '%s.jpg' % (name))
  150. save_path = os.path.join(image_dir, image_name)
  151. util.save_image(image_numpy, save_path, create_dir=True)
  152. ims.append(image_name)
  153. txts.append(label)
  154. links.append(image_name)
  155. webpage.add_images(ims, txts, links, width=self.win_size)

第三届计图人工智能挑战赛——风格及语义引导的风景图片生成赛道项目，由 Jittor（计图）框架实现。