You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base_dataset.py 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. """
  2. Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
  3. Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
  4. """
  5. from jittor.dataset.dataset import Dataset
  6. import jittor.transform as transform
  7. from PIL import Image
  8. import numpy as np
  9. import random
  class BaseDataset(Dataset):
      """Base dataset built on the jittor ``Dataset`` class.

      The methods defined here carry no logic of their own: construction only
      runs the parent initializer, and the two hooks are no-ops (presumably
      meant to be overridden by concrete datasets -- confirm against
      subclasses).
      """

      def __init__(self):
          # Zero-argument super() resolves to the same parent initializer.
          super().__init__()

      @staticmethod
      def modify_commandline_options(parser, is_train):
          """Return ``parser`` untouched; ``is_train`` is not consulted here."""
          return parser

      def initialize(self, opt):
          """Accept ``opt`` and do nothing with it; returns ``None``."""
          return None
  def get_params(opt, size):
      """Draw the random crop offset and flip decision for one image.

      The crop offset is sampled so that a ``crop_size`` square fits inside the
      image *after* the scaling implied by ``opt.preprocess_mode``.

      Args:
          opt: options object; reads ``preprocess_mode``, ``load_size`` and
              ``crop_size``.
          size: ``(width, height)`` of the source image.

      Returns:
          dict with ``'crop_pos'`` -> ``(x, y)`` top-left corner of the crop and
          ``'flip'`` -> bool, True with probability one half.
      """
      w, h = size
      # Modes not listed below leave the image size unchanged.
      new_w, new_h = w, h

      preprocess = opt.preprocess_mode
      if preprocess == 'resize_and_crop':
          new_w = new_h = opt.load_size
      elif preprocess == 'scale_width_and_crop':
          new_w = opt.load_size
          new_h = opt.load_size * h // w
      elif preprocess == 'scale_shortside_and_crop':
          short_side, long_side = min(w, h), max(w, h)
          scaled_long = int(opt.load_size * long_side / short_side)
          # After scaling, the short side measures load_size. Using its
          # original length here would size the crop range for the wrong image.
          if w == short_side:
              new_w, new_h = opt.load_size, scaled_long
          else:
              new_w, new_h = scaled_long, opt.load_size

      # Clamp at 0 so an image smaller than the crop always gets offset 0.
      x = random.randint(0, max(0, new_w - opt.crop_size))
      y = random.randint(0, max(0, new_h - opt.crop_size))

      flip = random.random() > 0.5
      return {'crop_pos': (x, y), 'flip': flip}
  def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True):
      """Assemble the preprocessing pipeline selected by ``opt``.

      Args:
          opt: options object; reads ``preprocess_mode``, ``load_size``,
              ``crop_size``, ``isTrain``, ``no_flip`` and, in ``'fixed'`` mode,
              ``aspect_ratio``.
          params: dict providing ``'crop_pos'`` and ``'flip'``; both are looked
              up when the corresponding step runs, not when it is built.
          method: resampling filter handed to every resizing step.
          normalize: append ``ImageNormalize(mean=[0.5], std=[0.5])`` when true.
          toTensor: append ``ToTensor()`` when true.

      Returns:
          ``transform.Compose`` over the selected steps, in the order built here.
      """
      preprocess = opt.preprocess_mode
      steps = []

      # Scaling stage: substring tests, first match wins.
      if 'resize' in preprocess:
          steps.append(
              transform.Resize([opt.load_size, opt.load_size], mode=method))
      elif 'scale_width' in preprocess:
          steps.append(transform.Lambda(
              lambda img: __scale_width(img, opt.load_size, method)))
      elif 'scale_shortside' in preprocess:
          steps.append(transform.Lambda(
              lambda img: __scale_shortside(img, opt.load_size, method)))

      if 'crop' in preprocess:
          steps.append(transform.Lambda(
              lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

      if preprocess == 'none':
          # No explicit resize requested: snap both sides to a multiple of 32.
          steps.append(transform.Lambda(
              lambda img: __make_power_2(img, 32, method)))

      if preprocess == 'fixed':
          fixed_w = opt.crop_size
          fixed_h = round(opt.crop_size / opt.aspect_ratio)
          steps.append(transform.Lambda(
              lambda img: __resize(img, fixed_w, fixed_h, method)))

      if opt.isTrain and not opt.no_flip:
          steps.append(transform.Lambda(
              lambda img: __flip(img, params['flip'])))

      if toTensor:
          steps.append(transform.ToTensor())
      if normalize:
          steps.append(transform.ImageNormalize(mean=[0.5], std=[0.5]))

      return transform.Compose(steps)
  def normalize():
      """Build the ``ImageNormalize`` transform with mean ``[0.5]`` and std ``[0.5]``."""
      stats = {'mean': [0.5], 'std': [0.5]}
      return transform.ImageNormalize(**stats)
  def __resize(img, w, h, method=Image.BICUBIC):
      """Resize ``img`` to exactly ``w`` by ``h`` using the ``method`` filter."""
      target_size = (w, h)
      return img.resize(target_size, method)
  def __make_power_2(img, base, method=Image.BICUBIC):
      """Round both image sides to the nearest multiple of ``base``.

      Despite the name, the sides are rounded to multiples of ``base``, not to
      powers of two. ``img`` itself is returned when neither side changes.
      """
      ow, oh = img.size
      w, h = (int(round(side / base) * base) for side in (ow, oh))
      if (w, h) == (ow, oh):
          return img
      return img.resize((w, h), method)
  def __scale_width(img, target_width, method=Image.BICUBIC):
      """Scale ``img`` to ``target_width``, adjusting the height in proportion.

      ``img`` itself is returned when its width already equals ``target_width``.
      """
      ow, oh = img.size
      if ow != target_width:
          # Height is truncated towards zero, not rounded.
          scaled_height = int(target_width * oh / ow)
          return img.resize((target_width, scaled_height), method)
      return img
  def __scale_shortside(img, target_width, method=Image.BICUBIC):
      """Resize ``img`` so its shorter side equals ``target_width``.

      The longer side is scaled by the same factor (truncated to an int), so the
      aspect ratio is preserved.

      Args:
          img: image exposing ``size`` as ``(width, height)`` and
              ``resize(size, resample)``.
          target_width: desired length of the shorter side after scaling.
          method: resampling filter passed to ``img.resize``.

      Returns:
          ``img`` itself when the shorter side already equals ``target_width``,
          otherwise the resized image.
      """
      ow, oh = img.size
      short_side, long_side = min(ow, oh), max(ow, oh)
      if short_side == target_width:
          return img

      scaled_long = int(target_width * long_side / short_side)
      # The short side must become target_width. Reusing its original length
      # would stretch only the long side and distort the image.
      if ow == short_side:
          new_size = (target_width, scaled_long)
      else:
          new_size = (scaled_long, target_width)
      return img.resize(new_size, method)
  def __crop(img, pos, size):
      """Cut a ``size`` x ``size`` square out of ``img``.

      Args:
          img: image exposing ``crop(box)``.
          pos: ``(x, y)`` top-left corner of the square.
          size: side length of the square.

      Returns:
          Whatever ``img.crop`` returns for the box
          ``(x, y, x + size, y + size)``.
      """
      # The image's own dimensions are not needed to build the crop box.
      left, top = pos
      return img.crop((left, top, left + size, top + size))
  def __flip(img, flip):
      """Mirror ``img`` left-to-right when ``flip`` is truthy, else return it as is."""
      return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img

第三届计图人工智能挑战赛——风格及语义引导的风景图片生成赛道项目,由jittor计图框架实现