You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

utils.py 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import hashlib
  2. import os
  3. import re
  4. import shutil
  5. import sys
  6. import tempfile
  7. import torch
  8. try:
  9. from requests.utils import urlparse
  10. from requests import get as urlopen
  11. requests_available = True
  12. except ImportError:
  13. requests_available = False
  14. if sys.version_info[0] == 2:
  15. from urlparse import urlparse # noqa f811
  16. from urllib2 import urlopen # noqa f811
  17. else:
  18. from urllib.request import urlopen
  19. from urllib.parse import urlparse
  20. try:
  21. from tqdm.auto import tqdm
  22. except:
  23. from fastNLP.core.utils import _pseudo_tqdm as tqdm
  24. # matches bfd8deac from resnet18-bfd8deac.pth
  25. HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
  26. def load_url(url, model_dir=None, map_location=None, progress=True):
  27. r"""Loads the Torch serialized object at the given URL.
  28. If the object is already present in `model_dir`, it's deserialized and
  29. returned. The filename part of the URL should follow the naming convention
  30. ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
  31. digits of the SHA256 hash of the contents of the file. The hash is used to
  32. ensure unique names and to verify the contents of the file.
  33. The default value of `model_dir` is ``$TORCH_HOME/models`` where
  34. ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
  35. overridden with the ``$TORCH_MODEL_ZOO`` environment variable.
  36. Args:
  37. url (string): URL of the object to download
  38. model_dir (string, optional): directory in which to save the object
  39. map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
  40. progress (bool, optional): whether or not to display a progress bar to stderr
  41. Example:
  42. # >>> state_dict = model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
  43. """
  44. if model_dir is None:
  45. torch_home = os.path.expanduser(os.getenv('fastNLP_HOME', '~/.fastNLP'))
  46. model_dir = os.getenv('fastNLP_MODEL_ZOO', os.path.join(torch_home, 'models'))
  47. if not os.path.exists(model_dir):
  48. os.makedirs(model_dir)
  49. parts = urlparse(url)
  50. filename = os.path.basename(parts.path)
  51. cached_file = os.path.join(model_dir, filename)
  52. if not os.path.exists(cached_file):
  53. sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
  54. # hash_prefix = HASH_REGEX.search(filename).group(1)
  55. _download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
  56. return torch.load(cached_file, map_location=map_location)
  57. def _download_url_to_file(url, dst, hash_prefix, progress):
  58. if requests_available:
  59. u = urlopen(url, stream=True)
  60. file_size = int(u.headers["Content-Length"])
  61. u = u.raw
  62. else:
  63. u = urlopen(url)
  64. meta = u.info()
  65. if hasattr(meta, 'getheaders'):
  66. file_size = int(meta.getheaders("Content-Length")[0])
  67. else:
  68. file_size = int(meta.get_all("Content-Length")[0])
  69. f = tempfile.NamedTemporaryFile(delete=False)
  70. try:
  71. if hash_prefix is not None:
  72. sha256 = hashlib.sha256()
  73. with tqdm(total=file_size, disable=not progress) as pbar:
  74. while True:
  75. buffer = u.read(8192)
  76. if len(buffer) == 0:
  77. break
  78. f.write(buffer)
  79. if hash_prefix is not None:
  80. sha256.update(buffer)
  81. pbar.update(len(buffer))
  82. f.close()
  83. if hash_prefix is not None:
  84. digest = sha256.hexdigest()
  85. if digest[:len(hash_prefix)] != hash_prefix:
  86. raise RuntimeError('invalid hash value (expected "{}", got "{}")'
  87. .format(hash_prefix, digest))
  88. shutil.move(f.name, dst)
  89. finally:
  90. f.close()
  91. if os.path.exists(f.name):
  92. os.remove(f.name)
  93. if tqdm is None:
  94. # fake tqdm if it's not installed
  95. class tqdm(object):
  96. def __init__(self, total, disable=False):
  97. self.total = total
  98. self.disable = disable
  99. self.n = 0
  100. def update(self, n):
  101. if self.disable:
  102. return
  103. self.n += n
  104. sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
  105. sys.stderr.flush()
  106. def __enter__(self):
  107. return self
  108. def __exit__(self, exc_type, exc_val, exc_tb):
  109. if self.disable:
  110. return
  111. sys.stderr.write('\n')