You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. # 包含一些与网络无关的工具
  2. import glob
  3. import os
  4. import random
  5. import zipfile
  6. import cv2
  7. import torch
  8. def get_dataset_list(dataset_path):
  9. if not os.path.exists(dataset_path + '/dataset_list.txt'):
  10. all_list = glob.glob(dataset_path + '/labels' + '/*.png')
  11. random.shuffle(all_list)
  12. with open(dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f:
  13. for line in all_list:
  14. f.write(os.path.basename(line.replace('\\', '/')) + '\n')
  15. print('已生成新的数据list')
  16. return all_list
  17. else:
  18. all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
  19. return all_list
  20. def zip_dir(dir_path, zip_path):
  21. """
  22. 压缩文件
  23. :param dir_path: 目标文件夹路径
  24. :param zip_path: 压缩后的文件夹路径
  25. """
  26. ziper = zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED)
  27. for root, dirnames, filenames in os.walk(dir_path):
  28. file_path = root.replace(dir_path, '') # 去掉根路径,只对目标文件夹下的文件及文件夹进行压缩
  29. # 循环出一个个文件名
  30. for filename in filenames:
  31. ziper.write(os.path.join(root, filename), os.path.join(file_path, filename))
  32. ziper.close()
  33. def ncolors(num_colors):
  34. """
  35. 生成区别度较大的几种颜色
  36. copy: https://blog.csdn.net/choumin/article/details/90320297
  37. :param num_colors: 颜色数
  38. :return:
  39. """
  40. def get_n_hls_colors(num):
  41. import random
  42. hls_colors = []
  43. i = 0
  44. step = 360.0 / num
  45. while i < 360:
  46. h = i
  47. s = 90 + random.random() * 10
  48. li = 50 + random.random() * 10
  49. _hlsc = [h / 360.0, li / 100.0, s / 100.0]
  50. hls_colors.append(_hlsc)
  51. i += step
  52. return hls_colors
  53. import colorsys
  54. rgb_colors = []
  55. if num_colors < 1:
  56. return rgb_colors
  57. for hlsc in get_n_hls_colors(num_colors):
  58. _r, _g, _b = colorsys.hls_to_rgb(hlsc[0], hlsc[1], hlsc[2])
  59. r, g, b = [int(x * 255.0) for x in (_r, _g, _b)]
  60. rgb_colors.append([r, g, b])
  61. return rgb_colors
  62. def visual_label(dataset_path, n_classes):
  63. """
  64. 将标签可视化
  65. :param dataset_path: 地址
  66. :param n_classes: 类别数
  67. """
  68. label_path = os.path.join(dataset_path, 'test', 'labels').replace('\\', '/')
  69. label_image_list = glob.glob(label_path + '/*.png')
  70. label_image_list.sort()
  71. from torchvision import transforms
  72. trans_factory = transforms.ToPILImage()
  73. if not os.path.exists(dataset_path + '/visual_label'):
  74. os.mkdir(dataset_path + '/visual_label')
  75. for index in range(len(label_image_list)):
  76. label_image = cv2.imread(label_image_list[index], -1)
  77. name = os.path.basename(label_image_list[index])
  78. trans_factory(torch.from_numpy(label_image).float() / n_classes).save(
  79. dataset_path + '/visual_label/' + name,
  80. quality=95)
  81. def get_ckpt_path(version_nth: int, kth_fold: int):
  82. if version_nth is None:
  83. return None
  84. else:
  85. version_name = f'version_{version_nth + kth_fold}'
  86. checkpoints_path = './logs/default/' + version_name + '/checkpoints'
  87. ckpt_path = glob.glob(checkpoints_path + '/*.ckpt')
  88. return ckpt_path[0].replace('\\', '/')
  89. def rwxl():
  90. # 写
  91. # dataset_xl = xl.Workbook(write_only=True)
  92. # dataset_sh = dataset_xl.create_sheet('dataset', 0)
  93. # for row in range(self.x.shape[0]):
  94. # for col in range(self.x.shape[1]):
  95. # dataset_sh.cell(row + 1, col + 1).value = float(self.x[row, col])
  96. # dataset_sh.cell(row + 1, self.x.shape[1] + 1).value = float(self.y[row])
  97. # dataset_xl.save(dataset_path + '/dataset.xlsx')
  98. # dataset_xl.close()
  99. # 读
  100. # dataset_xl = xl.load_workbook(dataset_path + '/dataset_list.xlsx', read_only=True)
  101. # dataset_sh = dataset_xl.get_sheet_by_name('dataset_list')
  102. # temp = [[dataset_sh[row + 1][col].value for col in range(config['dim_in'] + 1)] for row in
  103. # range(config['dataset_len'])]
  104. # dataset_xl.close()
  105. pass
  106. if __name__ == "__main__":
  107. get_ckpt_path('version_0')

基于 PyTorch Lightning 的机器学习模板，用于对机器学习算法进行训练、验证、测试等，目前实现了神经网络、深度学习、k 折交叉验证、自动保存训练信息等功能。

Contributors (1)