You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

utils.py 8.2 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. from __future__ import (absolute_import, division,
  2. print_function, unicode_literals)
  3. import json
  4. import logging
  5. import os
  6. import shutil
  7. import tempfile
  8. from functools import wraps
  9. from hashlib import sha256
  10. import sys
  11. from io import open
  12. import boto3
  13. import requests
  14. from botocore.exceptions import ClientError
  15. from tqdm import tqdm
  16. try:
  17. from urllib.parse import urlparse
  18. except ImportError:
  19. from urlparse import urlparse
  20. try:
  21. from pathlib import Path
  22. PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
  23. Path.home() / '.pytorch_pretrained_bert'))
  24. except AttributeError:
  25. PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
  26. os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
  27. logger = logging.getLogger(__name__) # pylint: disable=invalid-name
  28. def url_to_filename(url, etag=None):
  29. """
  30. Convert `url` into a hashed filename in a repeatable way.
  31. If `etag` is specified, append its hash to the url's, delimited
  32. by a period.
  33. """
  34. url_bytes = url.encode('utf-8')
  35. url_hash = sha256(url_bytes)
  36. filename = url_hash.hexdigest()
  37. if etag:
  38. etag_bytes = etag.encode('utf-8')
  39. etag_hash = sha256(etag_bytes)
  40. filename += '.' + etag_hash.hexdigest()
  41. return filename
  42. def filename_to_url(filename, cache_dir=None):
  43. """
  44. Return the url and etag (which may be ``None``) stored for `filename`.
  45. Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
  46. """
  47. if cache_dir is None:
  48. cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
  49. if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
  50. cache_dir = str(cache_dir)
  51. cache_path = os.path.join(cache_dir, filename)
  52. if not os.path.exists(cache_path):
  53. raise EnvironmentError("file {} not found".format(cache_path))
  54. meta_path = cache_path + '.json'
  55. if not os.path.exists(meta_path):
  56. raise EnvironmentError("file {} not found".format(meta_path))
  57. with open(meta_path, encoding="utf-8") as meta_file:
  58. metadata = json.load(meta_file)
  59. url = metadata['url']
  60. etag = metadata['etag']
  61. return url, etag
  62. def cached_path(url_or_filename, cache_dir=None):
  63. """
  64. Given something that might be a URL (or might be a local path),
  65. determine which. If it's a URL, download the file and cache it, and
  66. return the path to the cached file. If it's already a local path,
  67. make sure the file exists and then return the path.
  68. """
  69. if cache_dir is None:
  70. cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
  71. if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
  72. url_or_filename = str(url_or_filename)
  73. if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
  74. cache_dir = str(cache_dir)
  75. parsed = urlparse(url_or_filename)
  76. if parsed.scheme in ('http', 'https', 's3'):
  77. # URL, so get it from the cache (downloading if necessary)
  78. return get_from_cache(url_or_filename, cache_dir)
  79. elif os.path.exists(url_or_filename):
  80. # File, and it exists.
  81. return url_or_filename
  82. elif parsed.scheme == '':
  83. # File, but it doesn't exist.
  84. raise EnvironmentError("file {} not found".format(url_or_filename))
  85. else:
  86. # Something unknown
  87. raise ValueError(
  88. "unable to parse {} as a URL or as a local path".format(url_or_filename))
  89. def split_s3_path(url):
  90. """Split a full s3 path into the bucket name and path."""
  91. parsed = urlparse(url)
  92. if not parsed.netloc or not parsed.path:
  93. raise ValueError("bad s3 path {}".format(url))
  94. bucket_name = parsed.netloc
  95. s3_path = parsed.path
  96. # Remove '/' at beginning of path.
  97. if s3_path.startswith("/"):
  98. s3_path = s3_path[1:]
  99. return bucket_name, s3_path
  100. def s3_request(func):
  101. """
  102. Wrapper function for s3 requests in order to create more helpful error
  103. messages.
  104. """
  105. @wraps(func)
  106. def wrapper(url, *args, **kwargs):
  107. try:
  108. return func(url, *args, **kwargs)
  109. except ClientError as exc:
  110. if int(exc.response["Error"]["Code"]) == 404:
  111. raise EnvironmentError("file {} not found".format(url))
  112. else:
  113. raise
  114. return wrapper
  115. @s3_request
  116. def s3_etag(url):
  117. """Check ETag on S3 object."""
  118. s3_resource = boto3.resource("s3")
  119. bucket_name, s3_path = split_s3_path(url)
  120. s3_object = s3_resource.Object(bucket_name, s3_path)
  121. return s3_object.e_tag
  122. @s3_request
  123. def s3_get(url, temp_file):
  124. """Pull a file directly from S3."""
  125. s3_resource = boto3.resource("s3")
  126. bucket_name, s3_path = split_s3_path(url)
  127. s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
  128. def http_get(url, temp_file):
  129. req = requests.get(url, stream=True)
  130. content_length = req.headers.get('Content-Length')
  131. total = int(content_length) if content_length is not None else None
  132. progress = tqdm(unit="B", total=total)
  133. for chunk in req.iter_content(chunk_size=1024):
  134. if chunk: # filter out keep-alive new chunks
  135. progress.update(len(chunk))
  136. temp_file.write(chunk)
  137. progress.close()
  138. def get_from_cache(url, cache_dir=None):
  139. """
  140. Given a URL, look for the corresponding dataset in the local cache.
  141. If it's not there, download it. Then return the path to the cached file.
  142. """
  143. if cache_dir is None:
  144. cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
  145. if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
  146. cache_dir = str(cache_dir)
  147. if not os.path.exists(cache_dir):
  148. os.makedirs(cache_dir)
  149. # Get eTag to add to filename, if it exists.
  150. if url.startswith("s3://"):
  151. etag = s3_etag(url)
  152. else:
  153. response = requests.head(url, allow_redirects=True)
  154. if response.status_code != 200:
  155. raise IOError("HEAD request failed for url {} with status code {}"
  156. .format(url, response.status_code))
  157. etag = response.headers.get("ETag")
  158. filename = url_to_filename(url, etag)
  159. # get cache path to put the file
  160. cache_path = os.path.join(cache_dir, filename)
  161. if not os.path.exists(cache_path):
  162. # Download to temporary file, then copy to cache dir once finished.
  163. # Otherwise you get corrupt cache entries if the download gets interrupted.
  164. with tempfile.NamedTemporaryFile() as temp_file:
  165. logger.info("%s not found in cache, downloading to %s",
  166. url, temp_file.name)
  167. # GET file object
  168. if url.startswith("s3://"):
  169. s3_get(url, temp_file)
  170. else:
  171. http_get(url, temp_file)
  172. # we are copying the file before closing it, so flush to avoid truncation
  173. temp_file.flush()
  174. # shutil.copyfileobj() starts at the current position, so go to the start
  175. temp_file.seek(0)
  176. logger.info("copying %s to cache at %s",
  177. temp_file.name, cache_path)
  178. with open(cache_path, 'wb') as cache_file:
  179. shutil.copyfileobj(temp_file, cache_file)
  180. logger.info("creating metadata file for %s", cache_path)
  181. meta = {'url': url, 'etag': etag}
  182. meta_path = cache_path + '.json'
  183. with open(meta_path, 'w', encoding="utf-8") as meta_file:
  184. json.dump(meta, meta_file)
  185. logger.info("removing temp file %s", temp_file.name)
  186. return cache_path
  187. def read_set_from_file(filename):
  188. '''
  189. Extract a de-duped collection (set) of text from a file.
  190. Expected file format is one item per line.
  191. '''
  192. collection = set()
  193. with open(filename, 'r', encoding='utf-8') as file_:
  194. for line in file_:
  195. collection.add(line.rstrip())
  196. return collection
  197. def get_file_extension(path, dot=True, lower=True):
  198. ext = os.path.splitext(path)[1]
  199. ext = ext if dot else ext[1:]
  200. return ext.lower() if lower else ext