GitOrigin-RevId: a48e107623
tags/v1.7.1.m1
| @@ -106,9 +106,7 @@ class CIFAR10(VisionDataset): | |||||
| def download(self): | def download(self): | ||||
| url = self.url_path + self.raw_file_name | url = self.url_path + self.raw_file_name | ||||
| load_raw_data_from_url( | |||||
| url, self.raw_file_name, self.raw_file_md5, self.root, self.timeout | |||||
| ) | |||||
| load_raw_data_from_url(url, self.raw_file_name, self.raw_file_md5, self.root) | |||||
| self.process() | self.process() | ||||
| def untar(self, file_path, dirs): | def untar(self, file_path, dirs): | ||||
| @@ -118,7 +118,7 @@ class MNIST(VisionDataset): | |||||
| def download(self): | def download(self): | ||||
| for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5): | for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5): | ||||
| url = self.url_path + file_name | url = self.url_path + file_name | ||||
| load_raw_data_from_url(url, file_name, md5, self.root, self.timeout) | |||||
| load_raw_data_from_url(url, file_name, md5, self.root) | |||||
| def process(self, train): | def process(self, train): | ||||
| # load raw files and transform them into meta data and datasets Tuple(np.array) | # load raw files and transform them into meta data and datasets Tuple(np.array) | ||||
| @@ -27,9 +27,7 @@ def _default_dataset_root(): | |||||
| return default_dataset_root | return default_dataset_root | ||||
| def load_raw_data_from_url( | |||||
| url: str, filename: str, target_md5: str, raw_data_dir: str, timeout: int | |||||
| ): | |||||
| def load_raw_data_from_url(url: str, filename: str, target_md5: str, raw_data_dir: str): | |||||
| cached_file = os.path.join(raw_data_dir, filename) | cached_file = os.path.join(raw_data_dir, filename) | ||||
| logger.debug( | logger.debug( | ||||
| "load_raw_data_from_url: downloading to or using cached %s ...", cached_file | "load_raw_data_from_url: downloading to or using cached %s ...", cached_file | ||||
| @@ -41,7 +39,7 @@ def load_raw_data_from_url( | |||||
| " File may be downloaded multiple times. We recommend\n" | " File may be downloaded multiple times. We recommend\n" | ||||
| " users to download in single process first." | " users to download in single process first." | ||||
| ) | ) | ||||
| md5 = download_from_url(url, cached_file, http_read_timeout=timeout) | |||||
| md5 = download_from_url(url, cached_file) | |||||
| else: | else: | ||||
| md5 = calculate_md5(cached_file) | md5 = calculate_md5(cached_file) | ||||
| if target_md5 == md5: | if target_md5 == md5: | ||||
| @@ -25,7 +25,6 @@ from .const import ( | |||||
| DEFAULT_PROTOCOL, | DEFAULT_PROTOCOL, | ||||
| ENV_MGE_HOME, | ENV_MGE_HOME, | ||||
| ENV_XDG_CACHE_HOME, | ENV_XDG_CACHE_HOME, | ||||
| HTTP_READ_TIMEOUT, | |||||
| HUBCONF, | HUBCONF, | ||||
| HUBDEPENDENCY, | HUBDEPENDENCY, | ||||
| ) | ) | ||||
| @@ -263,14 +262,14 @@ def load_serialized_obj_from_url(url: str, model_dir=None) -> Any: | |||||
| " File may be downloaded multiple times. We recommend\n" | " File may be downloaded multiple times. We recommend\n" | ||||
| " users to download in single process first." | " users to download in single process first." | ||||
| ) | ) | ||||
| download_from_url(url, cached_file, HTTP_READ_TIMEOUT) | |||||
| download_from_url(url, cached_file) | |||||
| state_dict = _mge_load_serialized(cached_file) | state_dict = _mge_load_serialized(cached_file) | ||||
| return state_dict | return state_dict | ||||
| class pretrained: | class pretrained: | ||||
| r"""Decorator which helps to download pretrained weights from the given url. | |||||
| r"""Decorator which helps to download pretrained weights from the given url. Including fs, s3, http(s). | |||||
| For example, we can decorate a resnet18 function as follows | For example, we can decorate a resnet18 function as follows | ||||
| @@ -12,6 +12,7 @@ import shutil | |||||
| from tempfile import NamedTemporaryFile | from tempfile import NamedTemporaryFile | ||||
| import requests | import requests | ||||
| from megfile import smart_copy, smart_getmd5, smart_getsize | |||||
| from tqdm import tqdm | from tqdm import tqdm | ||||
| from ..logger import get_logger | from ..logger import get_logger | ||||
| @@ -26,41 +27,21 @@ class HTTPDownloadError(BaseException): | |||||
| r"""The class that represents http request error.""" | r"""The class that represents http request error.""" | ||||
| def download_from_url(url: str, dst: str, http_read_timeout=120): | |||||
| class Bar: | |||||
| def __init__(self, total=100): | |||||
| self._bar = tqdm(total=total, unit="iB", unit_scale=True, ncols=80) | |||||
| def __call__(self, bytes_num): | |||||
| self._bar.update(bytes_num) | |||||
| def download_from_url(url: str, dst: str): | |||||
| r"""Downloads file from given url to ``dst``. | r"""Downloads file from given url to ``dst``. | ||||
| Args: | Args: | ||||
| url: source URL. | url: source URL. | ||||
| dst: saving path. | dst: saving path. | ||||
| http_read_timeout: how many seconds to wait for data before giving up. | |||||
| """ | """ | ||||
| dst = os.path.expanduser(dst) | dst = os.path.expanduser(dst) | ||||
| dst_dir = os.path.dirname(dst) | |||||
| resp = requests.get( | |||||
| url, timeout=(HTTP_CONNECTION_TIMEOUT, http_read_timeout), stream=True | |||||
| ) | |||||
| if resp.status_code != 200: | |||||
| raise HTTPDownloadError("An error occured when downloading from {}".format(url)) | |||||
| md5 = hashlib.md5() | |||||
| total_size = int(resp.headers.get("Content-Length", 0)) | |||||
| bar = tqdm( | |||||
| total=total_size, unit="iB", unit_scale=True, ncols=80 | |||||
| ) # pylint: disable=blacklisted-name | |||||
| try: | |||||
| with NamedTemporaryFile("w+b", delete=False, suffix=".tmp", dir=dst_dir) as f: | |||||
| logger.info("Download file to temp file %s", f.name) | |||||
| for chunk in resp.iter_content(CHUNK_SIZE): | |||||
| if not chunk: | |||||
| break | |||||
| bar.update(len(chunk)) | |||||
| f.write(chunk) | |||||
| md5.update(chunk) | |||||
| bar.close() | |||||
| shutil.move(f.name, dst) | |||||
| finally: | |||||
| # ensure tmp file is removed | |||||
| if os.path.exists(f.name): | |||||
| os.remove(f.name) | |||||
| return md5.hexdigest() | |||||
| smart_copy(url, dst, callback=Bar(total=smart_getsize(url))) | |||||
| return smart_getmd5(dst) | |||||
| @@ -8,3 +8,4 @@ redispy | |||||
| deprecated | deprecated | ||||
| mprop | mprop | ||||
| wheel | wheel | ||||
| megfile>=0.0.10 | |||||