|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """
- hub for loading models:
- Users can load pre-trained models using mindspore.hub.load() API.
- """
- import os
- import re
- import shutil
- import tarfile
- import hashlib
- from urllib.request import urlretrieve
- import requests
- from bs4 import BeautifulSoup
-
- import mindspore
- import mindspore.nn as nn
- from mindspore import log as logger
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
-
- DOWNLOAD_BASIC_URL = "http://download.mindspore.cn/model_zoo"
- OFFICIAL_NAME = "official"
- DEFAULT_CACHE_DIR = '.cache'
- MODEL_TARGET_CV = ['alexnet', 'fasterrcnn', 'googlenet', 'lenet', 'resnet', 'resnet50', 'ssd', 'vgg', 'yolo']
- MODEL_TARGET_NLP = ['bert', 'mass', 'transformer']
-
-
- def _packing_targz(output_filename, savepath=DEFAULT_CACHE_DIR):
- """
- Packing the input filename to filename.tar.gz in source dir.
- """
- try:
- with tarfile.open(output_filename, "w:gz") as tar:
- tar.add(savepath, arcname=os.path.basename(savepath))
- except Exception as e:
- raise OSError("Cannot tar file {} for - {}".format(output_filename, e))
-
-
- def _unpacking_targz(input_filename, savepath=DEFAULT_CACHE_DIR):
- """
- Unpacking the input filename to dirs.
- """
- try:
- t = tarfile.open(input_filename)
- t.extractall(path=savepath)
- except Exception as e:
- raise OSError("Cannot untar file {} for - {}".format(input_filename, e))
-
-
- def _remove_path_if_exists(path):
- if os.path.exists(path):
- if os.path.isfile(path):
- os.remove(path)
- else:
- shutil.rmtree(path)
-
-
- def _create_path_if_not_exists(path):
- if not os.path.exists(path):
- if os.path.isfile(path):
- os.remove(path)
- else:
- os.mkdir(path)
-
-
- def _get_weights_file(url, hash_md5=None, savepath=DEFAULT_CACHE_DIR):
- """
- get checkpoint weight from giving url.
-
- Args:
- url(string): checkpoint tar.gz url path.
- hash_md5(string): checkpoint file md5.
- savepath(string): checkpoint download save path.
-
- Returns:
- string.
- """
-
- def reporthook(a, b, c):
- percent = a * b * 100.0 / c
- show_str = ('[%%-%ds]' % 70) % (int(percent * 80) * '#')
- print("\rDownloading:", show_str, " %5.1f%%" % (percent), end="")
-
- def md5sum(file_name, hash_md5):
- fp = open(file_name, 'rb')
- content = fp.read()
- fp.close()
- m = hashlib.md5()
- m.update(content.encode('utf-8'))
- download_md5 = m.hexdigest()
- return download_md5 == hash_md5
-
- _remove_path_if_exists(os.path.realpath(savepath))
- _create_path_if_not_exists(os.path.realpath(savepath))
- ckpt_name = os.path.basename(url.split("/")[-1])
- # identify file exist or not
- file_path = os.path.join(savepath, ckpt_name)
- if os.path.isfile(file_path):
- if hash_md5 and md5sum(file_path, hash_md5):
- print('File already exists!')
- return file_path
-
- file_path_ = file_path[:-7] if ".tar.gz" in file_path else file_path
- _remove_path_if_exists(file_path_)
-
- # download the checkpoint file
- print('Downloading data from url {}'.format(url))
- try:
- urlretrieve(url, file_path, reporthook=reporthook)
- except HTTPError as e:
- raise Exception(e.code, e.msg, url)
- except URLError as e:
- raise Exception(e.errno, e.reason, url)
- print('\nDownload finished!')
-
- # untar file_path
- _unpacking_targz(file_path, os.path.realpath(savepath))
-
- filesize = os.path.getsize(file_path)
- # turn the file size to Mb format
- print('File size = %.2f Mb' % (filesize / 1024 / 1024))
- return file_path_
-
-
- def _get_url_paths(url, ext='.tar.gz'):
- response = requests.get(url)
- if response.ok:
- response_text = response.text
- else:
- return response.raise_for_status()
- soup = BeautifulSoup(response_text, 'html.parser')
- parent = [url + node.get('href') for node in soup.find_all('a')
- if node.get('href').endswith(ext)]
- return parent
-
-
- def _get_file_from_url(base_url, base_name):
- idx = 0
- urls = _get_url_paths(base_url + "/")
- files = [url.split('/')[-1] for url in urls]
- for i, name in enumerate(files):
- if re.match(base_name + '*', name) is not None:
- idx = i
- break
- return urls[idx]
-
-
- def load_weights(network, network_name=None, force_reload=True, **kwargs):
- r"""
- Load a model from mindspore, with pretrained weights.
-
- Args:
- network (Cell): Cell network.
- network_name (string, optional): Cell network name get from network. Default: None.
- force_reload (bool, optional): Whether to force a fresh download unconditionally. Default: False.
- kwargs (dict, optional): The corresponding kwargs for download for model.
-
- - device_target (str, optional): Runtime device target. Default: 'ascend'.
- - dataset (str, optional): Dataset to train the network. Default: 'cifar10'.
- - version (str, optional): MindSpore version to save the checkpoint. Default: Latest version.
-
- Example:
- >>> hub.load(network, network_name='lenet',
- **{'device_target': 'ascend', 'dataset':'mnist', 'version': '0.5.0'})
- """
- if not isinstance(network, nn.Cell):
- logger.error("Failed to combine the net and the parameters.")
- msg = ("Argument net should be a Cell, but got {}.".format(type(network)))
- raise TypeError(msg)
-
- if network_name is None:
- if hasattr(network, network_name):
- network_name = network.network_name
- else:
- msg = "Should input network name, but got None."
- raise TypeError(msg)
-
- device_target = kwargs['device_target'] if kwargs['device_target'] else 'ascend'
- dataset = kwargs['dataset'] if kwargs['dataset'] else 'imagenet'
- version = kwargs['version'] if kwargs['version'] else mindspore.version.__version__
-
- if network_name.split("_")[0] in MODEL_TARGET_CV:
- model_type = "cv"
- elif network_name.split("_")[0] in MODEL_TARGET_NLP:
- model_type = "nlp"
- else:
- raise ValueError("Unsupported network {} download checkpoint.".format(network_name.split("_")[0]))
-
- download_base_url = "/".join([DOWNLOAD_BASIC_URL,
- OFFICIAL_NAME, model_type, network_name])
- download_file_name = "_".join(
- [network_name, device_target, version, dataset, OFFICIAL_NAME])
- download_url = _get_file_from_url(download_base_url, download_file_name)
-
- if force_reload:
- ckpt_path = _get_weights_file(download_url, None, DEFAULT_CACHE_DIR)
- else:
- raise ValueError("Unsupported not force reload.")
-
- ckpt_file = os.path.join(ckpt_path, network_name + ".ckpt")
- param_dict = load_checkpoint(ckpt_file)
- load_param_into_net(network, param_dict)
|