You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

__init__.py 1.9 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. """
  2. Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
  3. Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
  4. """
  5. import importlib
  6. from jittor.dataset.dataset import Dataset
  7. from data.base_dataset import BaseDataset
  8. def find_dataset_using_name(dataset_name):
  9. # Given the option --dataset [datasetname],
  10. # the file "datasets/datasetname_dataset.py"
  11. # will be imported.
  12. dataset_filename = "data." + dataset_name + "_dataset"
  13. datasetlib = importlib.import_module(dataset_filename)
  14. # In the file, the class called DatasetNameDataset() will
  15. # be instantiated. It has to be a subclass of BaseDataset,
  16. # and it is case-insensitive.
  17. dataset = None
  18. target_dataset_name = dataset_name.replace('_', '') + 'dataset'
  19. for name, cls in datasetlib.__dict__.items():
  20. if name.lower() == target_dataset_name.lower() \
  21. and issubclass(cls, BaseDataset):
  22. dataset = cls
  23. if dataset is None:
  24. raise ValueError("In %s.py, there should be a subclass of BaseDataset "
  25. "with class name that matches %s in lowercase." %
  26. (dataset_filename, target_dataset_name))
  27. return dataset
  28. def get_option_setter(dataset_name):
  29. dataset_class = find_dataset_using_name(dataset_name)
  30. return dataset_class.modify_commandline_options
  31. def create_dataloader(opt):
  32. dataset = find_dataset_using_name(opt.dataset_mode)
  33. instance = dataset()
  34. instance.initialize(opt)
  35. print("dataset [%s] of size %d was created" %
  36. (type(instance).__name__, len(instance)))
  37. dataloader = instance.set_attrs(
  38. batch_size=opt.batchSize,
  39. shuffle=not opt.serial_batches,
  40. num_workers=int(opt.nThreads),
  41. drop_last=opt.isTrain
  42. )
  43. return dataloader

第三届计图人工智能挑战赛——风格及语义引导的风景图片生成赛道项目，由 Jittor（计图）框架实现