|
- # Copyright 2020 Huawei Technologies Co., Ltd.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # import jsbeautifier
-
- import os
- import urllib
- import urllib.request
-
-
def create_data_cache_dir():
    """Ensure a ``data_cache`` directory exists under the current working directory.

    Creation is best-effort: if the directory cannot be created, a message is
    printed and the path is still returned, so a later file operation on it
    is what will actually fail.

    Returns:
        The path ``<cwd>/data_cache`` as a string.
    """
    target_directory = os.path.join(os.getcwd(), "data_cache")
    try:
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(target_directory, exist_ok=True)
    except OSError:
        print("Creation of the directory %s failed" % target_directory)
    return target_directory
-
-
def download_and_uncompress(files, source_url, target_directory, is_tar=False):
    """Download each archive in *files* and unpack it inside *target_directory*.

    A file is skipped (treated as cached) when either the archive itself or,
    for ``.gz`` names, the same path without the ``.gz`` suffix already
    exists in *target_directory*.

    Args:
        files: iterable of archive file names; each is appended to
            *source_url* to form the download URL.
        source_url: URL prefix the file names are appended to.
        target_directory: directory that receives the downloads and the
            unpacked data.
        is_tar: when true the archive is extracted with ``tar -xvf`` into
            *target_directory*; otherwise it is expanded with ``gunzip -f``.

    Raises:
        SystemError: if the external ``tar``/``gunzip`` command fails or
            cannot be started. The downloaded archive is removed first so
            the next run downloads it again.
    """
    # Local import keeps this function self-contained with respect to the
    # module's import block.
    import subprocess

    for f in files:
        url = source_url + f
        target_file = os.path.join(target_directory, f)
        # gunzip replaces "name.gz" with "name", so after a successful run
        # only the suffix-less file is left behind.
        if target_file.endswith(".gz"):
            unpacked_file = target_file[:-3]
        else:
            unpacked_file = target_file

        # Check whether the file was already downloaded.
        if os.path.exists(target_file) or os.path.exists(unpacked_file):
            print("Using cached dataset at ", target_file)
            continue

        complete = False
        try:
            urllib.request.urlretrieve(url, target_file)
            complete = True
        finally:
            # A partial download must not be mistaken for a cached archive
            # on the next run.
            if not complete and os.path.exists(target_file):
                os.remove(target_file)

        # Argument lists (no shell) keep paths containing spaces or shell
        # metacharacters intact.
        if is_tar:
            print("extracting from local tar file " + target_file)
            command = ["tar", "-C", target_directory, "-xvf", target_file]
        else:
            print("unzipping " + target_file)
            command = ["gunzip", "-f", target_file]

        try:
            rc = subprocess.run(command, check=False).returncode
        except OSError:
            # The tool is missing or not executable; treat it like a failure.
            rc = 1

        if rc != 0:
            print("Failed to uncompress ", target_file, " removing")
            if os.path.exists(target_file):
                os.remove(target_file)
            # Exit with an error so that the build script will fail.
            raise SystemError("failed to uncompress " + target_file)
-
-
def download_mnist(target_directory=None):
    """Download the four MNIST idx files into ``<target_directory>/mnist``.

    Args:
        target_directory: parent directory for the ``mnist`` folder. When
            None, the directory returned by ``create_data_cache_dir()`` is
            used.

    Returns:
        A tuple ``(mnist_directory, schema_path)`` where ``schema_path`` is
        ``<mnist_directory>/datasetSchema.json``. The schema path is only
        constructed here; this function does not create that file.
    """
    if target_directory is None:
        target_directory = create_data_cache_dir()

    # Create the mnist directory. makedirs also creates a missing parent
    # (e.g. a caller-supplied directory that does not exist yet), which
    # os.mkdir cannot do.
    target_directory = os.path.join(target_directory, "mnist")
    try:
        os.makedirs(target_directory, exist_ok=True)
    except OSError:
        print("Creation of the directory %s failed" % target_directory)

    mnist_url = "http://yann.lecun.com/exdb/mnist/"
    files = [
        'train-images-idx3-ubyte.gz',
        'train-labels-idx1-ubyte.gz',
        't10k-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz',
    ]
    download_and_uncompress(files, mnist_url, target_directory, is_tar=False)

    return target_directory, os.path.join(target_directory, "datasetSchema.json")
-
-
# URL prefix that CIFAR archive names are appended to.
CIFAR_URL = "https://www.cs.toronto.edu/~kriz/"


def download_cifar(target_directory, files, directory_from_tar):
    """Download one CIFAR archive, extract it, and flatten it into *target_directory*.

    After extraction, everything inside ``<target_directory>/<directory_from_tar>``
    is moved up into *target_directory* and the then-empty extracted
    directory is removed. If that directory is not present (for example
    because a cached archive was not extracted again), nothing is moved.

    Args:
        target_directory: directory that receives the archive and the data.
            When None, the directory returned by ``create_data_cache_dir()``
            is used.
        files: name of a single archive (a string, despite the plural name);
            it is appended to ``CIFAR_URL``.
        directory_from_tar: name of the directory the archive is expected to
            create inside *target_directory* when extracted.

    Returns:
        The directory the data files were moved into, i.e. *target_directory*
        (or the cache directory when None was passed). The extracted
        sub-directory is not returned because it is deleted once its
        contents have been moved.

    Raises:
        SystemError: if moving the files or removing the extracted directory
            fails. The downloaded archive is removed first so the next run
            starts from a fresh download.
    """
    if target_directory is None:
        target_directory = create_data_cache_dir()

    download_and_uncompress([files], CIFAR_URL, target_directory, is_tar=True)

    # Move the data out of the directory created by tar and into the
    # target directory itself.
    tar_dir_full_path = os.path.join(target_directory, directory_from_tar)
    if os.path.isdir(tar_dir_full_path):
        print("copy files back to target_directory")
        try:
            for entry in os.listdir(tar_dir_full_path):
                # Source and destination share a parent, so os.replace can
                # rename in place and overwrite an existing file.
                os.replace(
                    os.path.join(tar_dir_full_path, entry),
                    os.path.join(target_directory, entry),
                )
            os.rmdir(tar_dir_full_path)
        except OSError as err:
            print("error when moving files out of ", tar_dir_full_path, ": ", err)
            download_file = os.path.join(target_directory, files)
            print("removing " + download_file)
            if os.path.exists(download_file):
                os.remove(download_file)

            # Exit with an error so that the build script will fail.
            raise SystemError(
                "failed to move extracted files from " + tar_dir_full_path
            ) from err

    return target_directory
-
-
def download_cifar10(target_directory=None):
    """Fetch the CIFAR-10 binary archive by delegating to ``download_cifar``.

    Args:
        target_directory: forwarded unchanged to ``download_cifar``.

    Returns:
        Whatever ``download_cifar`` returns.
    """
    archive_name = "cifar-10-binary.tar.gz"
    extracted_dir_name = "cifar-10-batches-bin"
    return download_cifar(target_directory, archive_name, extracted_dir_name)
-
-
def download_cifar100(target_directory=None):
    """Fetch the CIFAR-100 binary archive by delegating to ``download_cifar``.

    Args:
        target_directory: forwarded unchanged to ``download_cifar``.

    Returns:
        Whatever ``download_cifar`` returns.
    """
    archive_name = "cifar-100-binary.tar.gz"
    extracted_dir_name = "cifar-100-binary"
    return download_cifar(target_directory, archive_name, extracted_dir_name)
-
-
def download_all_for_test(cwd):
    """Download the datasets used by the tests beneath *cwd*.

    Only MNIST is fetched here; it is placed under ``<cwd>/testMnistData``
    by ``download_mnist``.

    Args:
        cwd: base directory the test data directory name is joined onto.
    """
    mnist_parent_dir = os.path.join(cwd, "testMnistData")
    download_mnist(mnist_parent_dir)
-
-
# When run as a script, download all test datasets into the test
# directories under the current working directory.
if __name__ == "__main__":
    working_dir = os.getcwd()
    download_all_for_test(working_dir)
|