You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

prep_data.py 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # Copyright 2020 Huawei Technologies Co., Ltd.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # import jsbeautifier
  15. import os
  16. import urllib
  17. import urllib.request
  18. def create_data_cache_dir():
  19. cwd = os.getcwd()
  20. target_directory = os.path.join(cwd, "data_cache")
  21. try:
  22. if not (os.path.exists(target_directory)):
  23. os.mkdir(target_directory)
  24. except OSError:
  25. print("Creation of the directory %s failed" % target_directory)
  26. return target_directory;
  27. def download_and_uncompress(files, source_url, target_directory, is_tar=False):
  28. for f in files:
  29. url = source_url + f
  30. target_file = os.path.join(target_directory, f)
  31. ##check if file already downloaded
  32. if not (os.path.exists(target_file) or os.path.exists(target_file[:-3])):
  33. urllib.request.urlretrieve(url, target_file)
  34. if is_tar:
  35. print("extracting from local tar file " + target_file)
  36. rc = os.system("tar -C " + target_directory + " -xvf " + target_file)
  37. else:
  38. print("unzipping " + target_file)
  39. rc = os.system("gunzip -f " + target_file)
  40. if rc != 0:
  41. print("Failed to uncompress ", target_file, " removing")
  42. os.system("rm " + target_file)
  43. ##exit with error so that build script will fail
  44. raise SystemError
  45. else:
  46. print("Using cached dataset at ", target_file)
  47. def download_mnist(target_directory=None):
  48. if target_directory == None:
  49. target_directory = create_data_cache_dir()
  50. ##create mnst directory
  51. target_directory = os.path.join(target_directory, "mnist")
  52. try:
  53. if not (os.path.exists(target_directory)):
  54. os.mkdir(target_directory)
  55. except OSError:
  56. print("Creation of the directory %s failed" % target_directory)
  57. MNIST_URL = "http://yann.lecun.com/exdb/mnist/"
  58. files = ['train-images-idx3-ubyte.gz',
  59. 'train-labels-idx1-ubyte.gz',
  60. 't10k-images-idx3-ubyte.gz',
  61. 't10k-labels-idx1-ubyte.gz']
  62. download_and_uncompress(files, MNIST_URL, target_directory, is_tar=False)
  63. return target_directory, os.path.join(target_directory, "datasetSchema.json")
  64. CIFAR_URL = "https://www.cs.toronto.edu/~kriz/"
  65. def download_cifar(target_directory, files, directory_from_tar):
  66. if target_directory == None:
  67. target_directory = create_data_cache_dir()
  68. download_and_uncompress([files], CIFAR_URL, target_directory, is_tar=True)
  69. ##if target dir was specify move data from directory created by tar
  70. ##and put data into target dir
  71. if target_directory != None:
  72. tar_dir_full_path = os.path.join(target_directory, directory_from_tar)
  73. all_files = os.path.join(tar_dir_full_path, "*")
  74. cmd = "mv " + all_files + " " + target_directory
  75. if os.path.exists(tar_dir_full_path):
  76. print("copy files back to target_directory")
  77. print("Executing: ", cmd)
  78. rc1 = os.system(cmd)
  79. rc2 = os.system("rm -r " + tar_dir_full_path)
  80. if rc1 != 0 or rc2 != 0:
  81. print("error when running command: ", cmd)
  82. download_file = os.path.join(target_directory, files)
  83. print("removing " + download_file)
  84. os.system("rm " + download_file)
  85. ##exit with error so that build script will fail
  86. raise SystemError
  87. ##change target directory to directory after tar
  88. return os.path.join(target_directory, directory_from_tar)
  89. def download_cifar10(target_directory=None):
  90. return download_cifar(target_directory, "cifar-10-binary.tar.gz", "cifar-10-batches-bin")
  91. def download_cifar100(target_directory=None):
  92. return download_cifar(target_directory, "cifar-100-binary.tar.gz", "cifar-100-binary")
  93. def download_all_for_test(cwd):
  94. download_mnist(os.path.join(cwd, "testMnistData"))
  95. ##Download all datasets to existing test directories
  96. if __name__ == "__main__":
  97. download_all_for_test(os.getcwd())