You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

visualize.py 22 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import imageio
  5. import numpy as np
  6. import tensorlayer as tl
  7. from tensorlayer.lazy_imports import LazyImport
  8. cv2 = LazyImport("cv2")
  9. # Uncomment the following line if you got: _tkinter.TclError: no display name and no $DISPLAY environment variable
  10. # import matplotlib
  11. # matplotlib.use('Agg')
  12. __all__ = [
  13. 'read_image',
  14. 'read_images',
  15. 'save_image',
  16. 'save_images',
  17. 'draw_boxes_and_labels_to_image',
  18. 'draw_mpii_people_to_image',
  19. 'frame',
  20. 'CNN2d',
  21. 'images2d',
  22. 'tsne_embedding',
  23. 'draw_weights',
  24. 'W',
  25. ]
  26. def read_image(image, path=''):
  27. """Read one image.
  28. Parameters
  29. -----------
  30. image : str
  31. The image file name.
  32. path : str
  33. The image folder path.
  34. Returns
  35. -------
  36. numpy.array
  37. The image.
  38. """
  39. return imageio.imread(os.path.join(path, image))
  40. def read_images(img_list, path='', n_threads=10, printable=True):
  41. """Returns all images in list by given path and name of each image file.
  42. Parameters
  43. -------------
  44. img_list : list of str
  45. The image file names.
  46. path : str
  47. The image folder path.
  48. n_threads : int
  49. The number of threads to read image.
  50. printable : boolean
  51. Whether to print information when reading images.
  52. Returns
  53. -------
  54. list of numpy.array
  55. The images.
  56. """
  57. imgs = []
  58. for idx in range(0, len(img_list), n_threads):
  59. b_imgs_list = img_list[idx:idx + n_threads]
  60. b_imgs = tl.prepro.threading_data(b_imgs_list, fn=read_image, path=path)
  61. # tl.logging.info(b_imgs.shape)
  62. imgs.extend(b_imgs)
  63. if printable:
  64. tl.logging.info('read %d from %s' % (len(imgs), path))
  65. return imgs
  66. def save_image(image, image_path='_temp.png'):
  67. """Save a image.
  68. Parameters
  69. -----------
  70. image : numpy array
  71. [w, h, c]
  72. image_path : str
  73. path
  74. """
  75. try: # RGB
  76. imageio.imwrite(image_path, image)
  77. except Exception: # Greyscale
  78. imageio.imwrite(image_path, image[:, :, 0])
  79. def save_images(images, size, image_path='_temp.png'):
  80. """Save multiple images into one single image.
  81. Parameters
  82. -----------
  83. images : numpy array
  84. (batch, w, h, c)
  85. size : list of 2 ints
  86. row and column number.
  87. number of images should be equal or less than size[0] * size[1]
  88. image_path : str
  89. save path
  90. Examples
  91. ---------
  92. >>> import numpy as np
  93. >>> import tensorlayer as tl
  94. >>> images = np.random.rand(64, 100, 100, 3)
  95. >>> tl.visualize.save_images(images, [8, 8], 'temp.png')
  96. """
  97. if len(images.shape) == 3: # Greyscale [batch, h, w] --> [batch, h, w, 1]
  98. images = images[:, :, :, np.newaxis]
  99. def merge(images, size):
  100. h, w = images.shape[1], images.shape[2]
  101. img = np.zeros((h * size[0], w * size[1], 3), dtype=images.dtype)
  102. for idx, image in enumerate(images):
  103. i = idx % size[1]
  104. j = idx // size[1]
  105. img[j * h:j * h + h, i * w:i * w + w, :] = image
  106. return img
  107. def imsave(images, size, path):
  108. if np.max(images) <= 1 and (-1 <= np.min(images) < 0):
  109. images = ((images + 1) * 127.5).astype(np.uint8)
  110. elif np.max(images) <= 1 and np.min(images) >= 0:
  111. images = (images * 255).astype(np.uint8)
  112. return imageio.imwrite(path, merge(images, size))
  113. if len(images) > size[0] * size[1]:
  114. raise AssertionError("number of images should be equal or less than size[0] * size[1] {}".format(len(images)))
  115. return imsave(images, size, image_path)
  116. def draw_boxes_and_labels_to_image(
  117. image, classes, coords, scores, classes_list, is_center=True, is_rescale=True, save_name=None
  118. ):
  119. """Draw bboxes and class labels on image. Return or save the image with bboxes, example in the docs of ``tl.prepro``.
  120. Parameters
  121. -----------
  122. image : numpy.array
  123. The RGB image [height, width, channel].
  124. classes : list of int
  125. A list of class ID (int).
  126. coords : list of int
  127. A list of list for coordinates.
  128. - Should be [x, y, x2, y2] (up-left and botton-right format)
  129. - If [x_center, y_center, w, h] (set is_center to True).
  130. scores : list of float
  131. A list of score (float). (Optional)
  132. classes_list : list of str
  133. for converting ID to string on image.
  134. is_center : boolean
  135. Whether the coordinates is [x_center, y_center, w, h]
  136. - If coordinates are [x_center, y_center, w, h], set it to True for converting it to [x, y, x2, y2] (up-left and botton-right) internally.
  137. - If coordinates are [x1, x2, y1, y2], set it to False.
  138. is_rescale : boolean
  139. Whether to rescale the coordinates from pixel-unit format to ratio format.
  140. - If True, the input coordinates are the portion of width and high, this API will scale the coordinates to pixel unit internally.
  141. - If False, feed the coordinates with pixel unit format.
  142. save_name : None or str
  143. The name of image file (i.e. image.png), if None, not to save image.
  144. Returns
  145. -------
  146. numpy.array
  147. The saved image.
  148. References
  149. -----------
  150. - OpenCV rectangle and putText.
  151. - `scikit-image <http://scikit-image.org/docs/dev/api/skimage.draw.html#skimage.draw.rectangle>`__.
  152. """
  153. if len(coords) != len(classes):
  154. raise AssertionError("number of coordinates and classes are equal")
  155. if len(scores) > 0 and len(scores) != len(classes):
  156. raise AssertionError("number of scores and classes are equal")
  157. # don't change the original image, and avoid error https://stackoverflow.com/questions/30249053/python-opencv-drawing-errors-after-manipulating-array-with-numpy
  158. image = image.copy()
  159. imh, imw = image.shape[0:2]
  160. thick = int((imh + imw) // 430)
  161. for i, _v in enumerate(coords):
  162. if is_center:
  163. x, y, x2, y2 = tl.prepro.obj_box_coord_centroid_to_upleft_butright(coords[i])
  164. else:
  165. x, y, x2, y2 = coords[i]
  166. if is_rescale: # scale back to pixel unit if the coords are the portion of width and high
  167. x, y, x2, y2 = tl.prepro.obj_box_coord_scale_to_pixelunit([x, y, x2, y2], (imh, imw))
  168. cv2.rectangle(
  169. image,
  170. (int(x), int(y)),
  171. (int(x2), int(y2)), # up-left and botton-right
  172. [0, 255, 0],
  173. thick
  174. )
  175. cv2.putText(
  176. image,
  177. classes_list[classes[i]] + ((" %.2f" % (scores[i])) if (len(scores) != 0) else " "),
  178. (int(x), int(y)), # button left
  179. 0,
  180. 1.5e-3 * imh, # bigger = larger font
  181. [0, 0, 256], # self.meta['colors'][max_indx],
  182. int(thick / 2) + 1
  183. ) # bold
  184. if save_name is not None:
  185. # cv2.imwrite('_my.png', image)
  186. save_image(image, save_name)
  187. # if len(coords) == 0:
  188. # tl.logging.info("draw_boxes_and_labels_to_image: no bboxes exist, cannot draw !")
  189. return image
  190. def draw_mpii_pose_to_image(image, poses, save_name='image.png'):
  191. """Draw people(s) into image using MPII dataset format as input, return or save the result image.
  192. This is an experimental API, can be changed in the future.
  193. Parameters
  194. -----------
  195. image : numpy.array
  196. The RGB image [height, width, channel].
  197. poses : list of dict
  198. The people(s) annotation in MPII format, see ``tl.files.load_mpii_pose_dataset``.
  199. save_name : None or str
  200. The name of image file (i.e. image.png), if None, not to save image.
  201. Returns
  202. --------
  203. numpy.array
  204. The saved image.
  205. Examples
  206. --------
  207. >>> import pprint
  208. >>> import tensorlayer as tl
  209. >>> img_train_list, ann_train_list, img_test_list, ann_test_list = tl.files.load_mpii_pose_dataset()
  210. >>> image = tl.vis.read_image(img_train_list[0])
  211. >>> tl.vis.draw_mpii_pose_to_image(image, ann_train_list[0], 'image.png')
  212. >>> pprint.pprint(ann_train_list[0])
  213. References
  214. -----------
  215. - `MPII Keyponts and ID <http://human-pose.mpi-inf.mpg.de/#download>`__
  216. """
  217. # import skimage
  218. # don't change the original image, and avoid error https://stackoverflow.com/questions/30249053/python-opencv-drawing-errors-after-manipulating-array-with-numpy
  219. image = image.copy()
  220. imh, imw = image.shape[0:2]
  221. thick = int((imh + imw) // 430)
  222. # radius = int(image.shape[1] / 500) + 1
  223. radius = int(thick * 1.5)
  224. if image.max() < 1:
  225. image = image * 255
  226. for people in poses:
  227. # Pose Keyponts
  228. joint_pos = people['joint_pos']
  229. # draw sketch
  230. # joint id (0 - r ankle, 1 - r knee, 2 - r hip, 3 - l hip, 4 - l knee,
  231. # 5 - l ankle, 6 - pelvis, 7 - thorax, 8 - upper neck,
  232. # 9 - head top, 10 - r wrist, 11 - r elbow, 12 - r shoulder,
  233. # 13 - l shoulder, 14 - l elbow, 15 - l wrist)
  234. #
  235. # 9
  236. # 8
  237. # 12 ** 7 ** 13
  238. # * * *
  239. # 11 * 14
  240. # * * *
  241. # 10 2 * 6 * 3 15
  242. # * *
  243. # 1 4
  244. # * *
  245. # 0 5
  246. lines = [
  247. [(0, 1), [100, 255, 100]],
  248. [(1, 2), [50, 255, 50]],
  249. [(2, 6), [0, 255, 0]], # right leg
  250. [(3, 4), [100, 100, 255]],
  251. [(4, 5), [50, 50, 255]],
  252. [(6, 3), [0, 0, 255]], # left leg
  253. [(6, 7), [255, 255, 100]],
  254. [(7, 8), [255, 150, 50]], # body
  255. [(8, 9), [255, 200, 100]], # head
  256. [(10, 11), [255, 100, 255]],
  257. [(11, 12), [255, 50, 255]],
  258. [(12, 8), [255, 0, 255]], # right hand
  259. [(8, 13), [0, 255, 255]],
  260. [(13, 14), [100, 255, 255]],
  261. [(14, 15), [200, 255, 255]] # left hand
  262. ]
  263. for line in lines:
  264. start, end = line[0]
  265. if (start in joint_pos) and (end in joint_pos):
  266. cv2.line(
  267. image,
  268. (int(joint_pos[start][0]), int(joint_pos[start][1])),
  269. (int(joint_pos[end][0]), int(joint_pos[end][1])), # up-left and botton-right
  270. line[1],
  271. thick
  272. )
  273. # rr, cc, val = skimage.draw.line_aa(int(joint_pos[start][1]), int(joint_pos[start][0]), int(joint_pos[end][1]), int(joint_pos[end][0]))
  274. # image[rr, cc] = line[1]
  275. # draw circles
  276. for pos in joint_pos.items():
  277. _, pos_loc = pos # pos_id, pos_loc
  278. pos_loc = (int(pos_loc[0]), int(pos_loc[1]))
  279. cv2.circle(image, center=pos_loc, radius=radius, color=(200, 200, 200), thickness=-1)
  280. # rr, cc = skimage.draw.circle(int(pos_loc[1]), int(pos_loc[0]), radius)
  281. # image[rr, cc] = [0, 255, 0]
  282. # Head
  283. head_rect = people['head_rect']
  284. if head_rect: # if head exists
  285. cv2.rectangle(
  286. image,
  287. (int(head_rect[0]), int(head_rect[1])),
  288. (int(head_rect[2]), int(head_rect[3])), # up-left and botton-right
  289. [0, 180, 0],
  290. thick
  291. )
  292. if save_name is not None:
  293. # cv2.imwrite(save_name, image)
  294. save_image(image, save_name)
  295. return image
  296. draw_mpii_people_to_image = draw_mpii_pose_to_image
  297. def frame(I=None, second=5, saveable=True, name='frame', cmap=None, fig_idx=12836):
  298. """Display a frame. Make sure OpenAI Gym render() is disable before using it.
  299. Parameters
  300. ----------
  301. I : numpy.array
  302. The image.
  303. second : int
  304. The display second(s) for the image(s), if saveable is False.
  305. saveable : boolean
  306. Save or plot the figure.
  307. name : str
  308. A name to save the image, if saveable is True.
  309. cmap : None or str
  310. 'gray' for greyscale, None for default, etc.
  311. fig_idx : int
  312. matplotlib figure index.
  313. Examples
  314. --------
  315. >>> env = gym.make("Pong-v0")
  316. >>> observation = env.reset()
  317. >>> tl.visualize.frame(observation)
  318. """
  319. import matplotlib.pyplot as plt
  320. if saveable is False:
  321. plt.ion()
  322. plt.figure(fig_idx) # show all feature images
  323. if len(I.shape) and I.shape[-1] == 1: # (10,10,1) --> (10,10)
  324. I = I[:, :, 0]
  325. plt.imshow(I, cmap)
  326. plt.title(name)
  327. # plt.gca().xaxis.set_major_locator(plt.NullLocator()) # distable tick
  328. # plt.gca().yaxis.set_major_locator(plt.NullLocator())
  329. if saveable:
  330. plt.savefig(name + '.pdf', format='pdf')
  331. else:
  332. plt.draw()
  333. plt.pause(second)
  334. def CNN2d(CNN=None, second=10, saveable=True, name='cnn', fig_idx=3119362):
  335. """Display a group of RGB or Greyscale CNN masks.
  336. Parameters
  337. ----------
  338. CNN : numpy.array
  339. The image. e.g: 64 5x5 RGB images can be (5, 5, 3, 64).
  340. second : int
  341. The display second(s) for the image(s), if saveable is False.
  342. saveable : boolean
  343. Save or plot the figure.
  344. name : str
  345. A name to save the image, if saveable is True.
  346. fig_idx : int
  347. The matplotlib figure index.
  348. Examples
  349. --------
  350. >>> tl.visualize.CNN2d(network.all_params[0].eval(), second=10, saveable=True, name='cnn1_mnist', fig_idx=2012)
  351. """
  352. import matplotlib.pyplot as plt
  353. # tl.logging.info(CNN.shape) # (5, 5, 3, 64)
  354. # exit()
  355. n_mask = CNN.shape[3]
  356. n_row = CNN.shape[0]
  357. n_col = CNN.shape[1]
  358. n_color = CNN.shape[2]
  359. row = int(np.sqrt(n_mask))
  360. col = int(np.ceil(n_mask / row))
  361. plt.ion() # active mode
  362. fig = plt.figure(fig_idx)
  363. count = 1
  364. for _ir in range(1, row + 1):
  365. for _ic in range(1, col + 1):
  366. if count > n_mask:
  367. break
  368. fig.add_subplot(col, row, count)
  369. # tl.logging.info(CNN[:,:,:,count-1].shape, n_row, n_col) # (5, 1, 32) 5 5
  370. # exit()
  371. # plt.imshow(
  372. # np.reshape(CNN[count-1,:,:,:], (n_row, n_col)),
  373. # cmap='gray', interpolation="nearest") # theano
  374. if n_color == 1:
  375. plt.imshow(np.reshape(CNN[:, :, :, count - 1], (n_row, n_col)), cmap='gray', interpolation="nearest")
  376. elif n_color == 3:
  377. plt.imshow(
  378. np.reshape(CNN[:, :, :, count - 1], (n_row, n_col, n_color)), cmap='gray', interpolation="nearest"
  379. )
  380. else:
  381. raise Exception("Unknown n_color")
  382. plt.gca().xaxis.set_major_locator(plt.NullLocator()) # distable tick
  383. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  384. count = count + 1
  385. if saveable:
  386. plt.savefig(name + '.pdf', format='pdf')
  387. else:
  388. plt.draw()
  389. plt.pause(second)
  390. def images2d(images=None, second=10, saveable=True, name='images', dtype=None, fig_idx=3119362):
  391. """Display a group of RGB or Greyscale images.
  392. Parameters
  393. ----------
  394. images : numpy.array
  395. The images.
  396. second : int
  397. The display second(s) for the image(s), if saveable is False.
  398. saveable : boolean
  399. Save or plot the figure.
  400. name : str
  401. A name to save the image, if saveable is True.
  402. dtype : None or numpy data type
  403. The data type for displaying the images.
  404. fig_idx : int
  405. matplotlib figure index.
  406. Examples
  407. --------
  408. >>> X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False)
  409. >>> tl.visualize.images2d(X_train[0:100,:,:,:], second=10, saveable=False, name='cifar10', dtype=np.uint8, fig_idx=20212)
  410. """
  411. import matplotlib.pyplot as plt
  412. # tl.logging.info(images.shape) # (50000, 32, 32, 3)
  413. # exit()
  414. if dtype:
  415. images = np.asarray(images, dtype=dtype)
  416. n_mask = images.shape[0]
  417. n_row = images.shape[1]
  418. n_col = images.shape[2]
  419. n_color = images.shape[3]
  420. row = int(np.sqrt(n_mask))
  421. col = int(np.ceil(n_mask / row))
  422. plt.ion() # active mode
  423. fig = plt.figure(fig_idx)
  424. count = 1
  425. for _ir in range(1, row + 1):
  426. for _ic in range(1, col + 1):
  427. if count > n_mask:
  428. break
  429. fig.add_subplot(col, row, count)
  430. # tl.logging.info(images[:,:,:,count-1].shape, n_row, n_col) # (5, 1, 32) 5 5
  431. # plt.imshow(
  432. # np.reshape(images[count-1,:,:,:], (n_row, n_col)),
  433. # cmap='gray', interpolation="nearest") # theano
  434. if n_color == 1:
  435. plt.imshow(np.reshape(images[count - 1, :, :], (n_row, n_col)), cmap='gray', interpolation="nearest")
  436. # plt.title(name)
  437. elif n_color == 3:
  438. plt.imshow(images[count - 1, :, :], cmap='gray', interpolation="nearest")
  439. # plt.title(name)
  440. else:
  441. raise Exception("Unknown n_color")
  442. plt.gca().xaxis.set_major_locator(plt.NullLocator()) # distable tick
  443. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  444. count = count + 1
  445. if saveable:
  446. plt.savefig(name + '.pdf', format='pdf')
  447. else:
  448. plt.draw()
  449. plt.pause(second)
  450. def tsne_embedding(embeddings, reverse_dictionary, plot_only=500, second=5, saveable=False, name='tsne', fig_idx=9862):
  451. """Visualize the embeddings by using t-SNE.
  452. Parameters
  453. ----------
  454. embeddings : numpy.array
  455. The embedding matrix.
  456. reverse_dictionary : dictionary
  457. id_to_word, mapping id to unique word.
  458. plot_only : int
  459. The number of examples to plot, choice the most common words.
  460. second : int
  461. The display second(s) for the image(s), if saveable is False.
  462. saveable : boolean
  463. Save or plot the figure.
  464. name : str
  465. A name to save the image, if saveable is True.
  466. fig_idx : int
  467. matplotlib figure index.
  468. Examples
  469. --------
  470. >>> see 'tutorial_word2vec_basic.py'
  471. >>> final_embeddings = normalized_embeddings.eval()
  472. >>> tl.visualize.tsne_embedding(final_embeddings, labels, reverse_dictionary,
  473. ... plot_only=500, second=5, saveable=False, name='tsne')
  474. """
  475. import matplotlib.pyplot as plt
  476. def plot_with_labels(low_dim_embs, labels, figsize=(18, 18), second=5, saveable=True, name='tsne', fig_idx=9862):
  477. if low_dim_embs.shape[0] < len(labels):
  478. raise AssertionError("More labels than embeddings")
  479. if saveable is False:
  480. plt.ion()
  481. plt.figure(fig_idx)
  482. plt.figure(figsize=figsize) # in inches
  483. for i, label in enumerate(labels):
  484. x, y = low_dim_embs[i, :]
  485. plt.scatter(x, y)
  486. plt.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points', ha='right', va='bottom')
  487. if saveable:
  488. plt.savefig(name + '.pdf', format='pdf')
  489. else:
  490. plt.draw()
  491. plt.pause(second)
  492. try:
  493. from sklearn.manifold import TSNE
  494. from six.moves import xrange
  495. tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
  496. # plot_only = 500
  497. low_dim_embs = tsne.fit_transform(embeddings[:plot_only, :])
  498. labels = [reverse_dictionary[i] for i in xrange(plot_only)]
  499. plot_with_labels(low_dim_embs, labels, second=second, saveable=saveable, name=name, fig_idx=fig_idx)
  500. except ImportError:
  501. _err = "Please install sklearn and matplotlib to visualize embeddings."
  502. tl.logging.error(_err)
  503. raise ImportError(_err)
  504. def draw_weights(W=None, second=10, saveable=True, shape=None, name='mnist', fig_idx=2396512):
  505. """Visualize every columns of the weight matrix to a group of Greyscale img.
  506. Parameters
  507. ----------
  508. W : numpy.array
  509. The weight matrix
  510. second : int
  511. The display second(s) for the image(s), if saveable is False.
  512. saveable : boolean
  513. Save or plot the figure.
  514. shape : a list with 2 int or None
  515. The shape of feature image, MNIST is [28, 80].
  516. name : a string
  517. A name to save the image, if saveable is True.
  518. fig_idx : int
  519. matplotlib figure index.
  520. Examples
  521. --------
  522. >>> tl.visualize.draw_weights(network.all_params[0].eval(), second=10, saveable=True, name='weight_of_1st_layer', fig_idx=2012)
  523. """
  524. if shape is None:
  525. shape = [28, 28]
  526. import matplotlib.pyplot as plt
  527. if saveable is False:
  528. plt.ion()
  529. fig = plt.figure(fig_idx) # show all feature images
  530. n_units = W.shape[1]
  531. num_r = int(np.sqrt(n_units)) # 每行显示的个数 若25个hidden unit -> 每行显示5个
  532. num_c = int(np.ceil(n_units / num_r))
  533. count = int(1)
  534. for _row in range(1, num_r + 1):
  535. for _col in range(1, num_c + 1):
  536. if count > n_units:
  537. break
  538. fig.add_subplot(num_r, num_c, count)
  539. # ------------------------------------------------------------
  540. # plt.imshow(np.reshape(W[:,count-1],(28,28)), cmap='gray')
  541. # ------------------------------------------------------------
  542. feature = W[:, count - 1] / np.sqrt((W[:, count - 1]**2).sum())
  543. # feature[feature<0.0001] = 0 # value threshold
  544. # if count == 1 or count == 2:
  545. # print(np.mean(feature))
  546. # if np.std(feature) < 0.03: # condition threshold
  547. # feature = np.zeros_like(feature)
  548. # if np.mean(feature) < -0.015: # condition threshold
  549. # feature = np.zeros_like(feature)
  550. plt.imshow(
  551. np.reshape(feature, (shape[0], shape[1])), cmap='gray', interpolation="nearest"
  552. ) # , vmin=np.min(feature), vmax=np.max(feature))
  553. # plt.title(name)
  554. # ------------------------------------------------------------
  555. # plt.imshow(np.reshape(W[:,count-1] ,(np.sqrt(size),np.sqrt(size))), cmap='gray', interpolation="nearest")
  556. plt.gca().xaxis.set_major_locator(plt.NullLocator()) # distable tick
  557. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  558. count = count + 1
  559. if saveable:
  560. plt.savefig(name + '.pdf', format='pdf')
  561. else:
  562. plt.draw()
  563. plt.pause(second)
  564. W = draw_weights

TensorLayer 3.0 是一款兼容多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。