|
|
|
@@ -0,0 +1,212 @@ |
|
|
|
# Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
# |
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
# you may not use this file except in compliance with the License. |
|
|
|
# You may obtain a copy of the License at |
|
|
|
# |
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
# |
|
|
|
# Unless required by applicable law or agreed to in writing, software |
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================ |
|
|
|
""" |
|
|
|
hub for loading models: |
|
|
|
Users can load pre-trained models using mindspore.hub.load() API. |
|
|
|
""" |
|
|
|
import os |
|
|
|
import re |
|
|
|
import shutil |
|
|
|
import tarfile |
|
|
|
import hashlib |
|
|
|
from urllib.request import urlretrieve |
|
|
|
import requests |
|
|
|
from bs4 import BeautifulSoup |
|
|
|
|
|
|
|
import mindspore |
|
|
|
import mindspore.nn as nn |
|
|
|
from mindspore import log as logger |
|
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net |
|
|
|
|
|
|
|
# Root url of the MindSpore model-zoo download site.
DOWNLOAD_BASIC_URL = "http://download.mindspore.cn/model_zoo"
# Path segment / filename tag for officially released models.
OFFICIAL_NAME = "official"
# Default local directory checkpoints are downloaded into.
# NOTE(review): used verbatim, without os.path.expanduser -- confirm the
# '~' is expanded somewhere, otherwise a literal './~/.cache' dir is made.
DEFAULT_CACHE_DIR = '~/.cache'
# Network-name prefixes routed to the "cv" model-zoo sub-directory.
MODEL_TARGET_CV = ['alexnet', 'fasterrcnn', 'googlenet',
                   'lenet', 'resnet', 'ssd', 'vgg', 'yolo']
# Network-name prefixes routed to the "nlp" model-zoo sub-directory.
MODEL_TARGET_NLP = ['bert', 'mass', 'transformer']
|
|
|
|
|
|
|
|
|
|
|
def _packing_targz(output_filename, savepath="./"): |
|
|
|
""" |
|
|
|
Packing the input filename to filename.tar.gz in source dir. |
|
|
|
""" |
|
|
|
try: |
|
|
|
with tarfile.open(output_filename, "w:gz") as tar: |
|
|
|
tar.add(savepath, arcname=os.path.basename(savepath)) |
|
|
|
except Exception as e: |
|
|
|
raise OSError("Cannot tar file {} for - {}".format(output_filename, e)) |
|
|
|
|
|
|
|
|
|
|
|
def _unpacking_targz(input_filename, savepath="./"): |
|
|
|
""" |
|
|
|
Unpacking the input filename to dirs. |
|
|
|
""" |
|
|
|
try: |
|
|
|
t = tarfile.open(input_filename) |
|
|
|
t.extractall(path=savepath) |
|
|
|
except Exception as e: |
|
|
|
raise OSError("Cannot untar file {} for - {}".format(input_filename, e)) |
|
|
|
|
|
|
|
|
|
|
|
def _remove_path_if_exists(path): |
|
|
|
if os.path.exists(path): |
|
|
|
if os.path.isfile(path): |
|
|
|
os.remove(path) |
|
|
|
else: |
|
|
|
shutil.rmtree(path) |
|
|
|
|
|
|
|
|
|
|
|
def _create_path_if_not_exists(path): |
|
|
|
if os.path.exists(path): |
|
|
|
if os.path.isfile(path): |
|
|
|
os.remove(path) |
|
|
|
else: |
|
|
|
os.mkdir(path) |
|
|
|
|
|
|
|
|
|
|
|
def _get_weights_file(url, hash_md5=None, savepath='./'):
    """
    Download a checkpoint archive from ``url`` into ``savepath`` and unpack it.

    Args:
        url (string): checkpoint tar.gz url path.
        hash_md5 (string): expected md5 hex digest; when the file is already
            present and matches, the download is skipped.
        savepath (string): directory the archive is downloaded into.

    Returns:
        string, path of the downloaded archive inside ``savepath``.
    """
    # Bug fix: HTTPError/URLError were referenced below but never imported,
    # so any download failure raised NameError instead of the intended error.
    from urllib.error import HTTPError, URLError

    def reporthook(blocks, block_size, total_size):
        # Progress callback for urlretrieve.
        percent = blocks * block_size * 100.0 / total_size
        # Bug fix: the original repeated '#' int(percent * 80) times
        # (thousands of chars at 100%); scale to the 70-char bar instead.
        bar = '#' * int(min(percent, 100.0) * 70 / 100)
        show_str = ('[%%-%ds]' % 70) % bar
        print("\rDownloading:", show_str, " %5.1f%%" % percent, end="")

    def md5sum(file_name, hash_md5):
        # Compare the file's md5 hex digest with the expected one.
        # Bug fix: the file is read in binary mode, so its content is bytes;
        # the original called .encode('utf-8') on it (AttributeError), and
        # never closed the file handle.
        m = hashlib.md5()
        with open(file_name, 'rb') as fp:
            m.update(fp.read())
        return m.hexdigest() == hash_md5

    _create_path_if_not_exists(savepath)
    ckpt_name = os.path.basename(url.split("/")[-1])
    # Reuse an already-downloaded archive when its checksum matches.
    file_path = os.path.join(savepath, ckpt_name)
    if os.path.isfile(file_path):
        if hash_md5 and md5sum(file_path, hash_md5):
            print('File already exists!')
            return file_path

    # Remove any stale, previously-unpacked copy (archive name minus .tar.gz).
    extracted_path = file_path[:-7] if ".tar.gz" in file_path else file_path
    _remove_path_if_exists(extracted_path)

    # Download the checkpoint archive.
    # Bug fix: the original downloaded to the stripped (extension-less) name
    # and later called os.path.getsize on the full name, which never existed.
    print('Downloading data from url {}'.format(url))
    try:
        urlretrieve(url, file_path, reporthook=reporthook)
    except HTTPError as e:
        raise Exception(e.code, e.msg, url)
    except URLError as e:
        raise Exception(e.errno, e.reason, url)
    print('\nDownload finished!')

    # Bug fix: extract next to the archive; the original ignored savepath
    # and always unpacked into the current working directory.
    _unpacking_targz(file_path, savepath)

    # Report the archive size in Mb.
    filesize = os.path.getsize(file_path)
    print('File size = %.2f Mb' % (filesize / 1024 / 1024))
    return file_path
|
|
|
|
|
|
|
|
|
|
|
def _get_url_paths(url, ext='.tar.gz'):
    """
    List links on the page at ``url`` whose href ends with ``ext``.

    Args:
        url (string): page url; matching hrefs are appended to it.
        ext (string): required href suffix. Default: '.tar.gz'.

    Returns:
        list, ``url`` concatenated with each matching href.

    Raises:
        requests.HTTPError: if the page cannot be fetched.
    """
    response = requests.get(url)
    if not response.ok:
        return response.raise_for_status()
    soup = BeautifulSoup(response.text, 'html.parser')
    # Bug fix: anchors without an href attribute yield None, which the
    # original crashed on with AttributeError; skip them instead.
    hrefs = (node.get('href') for node in soup.find_all('a'))
    return [url + href for href in hrefs if href and href.endswith(ext)]
|
|
|
|
|
|
|
|
|
|
|
def _get_file_from_url(base_url, base_name):
    """
    Return the first url under ``base_url`` whose file name starts with
    ``base_name``; falls back to the first listed url when nothing matches
    (preserving the original behavior).

    Args:
        base_url (string): directory-listing url scanned via _get_url_paths.
        base_name (string): required file-name prefix.

    Returns:
        string, the selected url.
    """
    urls = _get_url_paths(base_url)
    # Bug fix: the original used re.match(base_name + '*', name); the
    # trailing '*' is a regex quantifier on base_name's last character,
    # not a glob -- a plain prefix test is what was intended.
    for candidate in urls:
        if candidate.split('/')[-1].startswith(base_name):
            return candidate
    return urls[0]
|
|
|
|
|
|
|
|
|
|
|
def load_weights(network, network_name=None, force_reload=True, **kwargs):
    r"""
    Load a model from mindspore, with pretrained weights.

    Args:
        network (Cell): Cell network.
        network_name (string, optional): Cell network name; when None it is
            read from ``network.network_name``. Default: None.
        force_reload (bool, optional): Whether to force a fresh download
            unconditionally. Default: True.
        **kwargs (optional): The corresponding kwargs for download for model.

            device_target (string, optional): Runtime device target. Default: 'ascend'.
            dataset (string, optional): Dataset the network was trained on. Default: 'imagenet'.
            version (string, optional): MindSpore version tag. Default: current version.

    Raises:
        TypeError: if ``network`` is not a Cell or no network name is available.
        ValueError: if the network name is not recognised, or force_reload is False.

    Example:
        >>> mindspore.hub.load(network, network_name='lenet',
                **{'device_target': 'ascend', 'dataset':'cifar10', 'version': 'beta0.5'})
    """
    if not isinstance(network, nn.Cell):
        logger.error("Failed to combine the net and the parameters.")
        msg = ("Argument net should be a Cell, but got {}.".format(type(network)))
        raise TypeError(msg)

    if network_name is None:
        # Bug fix: hasattr takes the attribute *name* as a string; the
        # original passed the network_name variable, which is None here.
        if hasattr(network, 'network_name'):
            network_name = network.network_name
        else:
            msg = "Should input network name, but got None."
            raise TypeError(msg)

    # Bug fix: the original indexed kwargs directly, raising KeyError when a
    # documented-optional kwarg was omitted; .get() falls back to the default.
    device_target = kwargs.get('device_target') or 'ascend'
    dataset = kwargs.get('dataset') or 'imagenet'
    version = kwargs.get('version') or mindspore.version.__version__

    prefix = network_name.split("_")[0]
    if prefix in MODEL_TARGET_CV:
        model_type = "cv"
    elif prefix in MODEL_TARGET_NLP:
        model_type = "nlp"
    else:
        # Bug fix: the original left model_type undefined here, producing an
        # opaque NameError below instead of a clear error.
        raise ValueError("Unsupported network name: {}.".format(network_name))

    download_base_url = "/".join([DOWNLOAD_BASIC_URL,
                                  OFFICIAL_NAME, model_type])
    download_file_name = "_".join(
        [network_name, device_target, version, dataset, OFFICIAL_NAME])
    download_url = _get_file_from_url(download_base_url, download_file_name)

    if force_reload:
        ckpt_path = _get_weights_file(download_url, None, DEFAULT_CACHE_DIR)
    else:
        raise ValueError("Unsupported not force reload.")

    ckpt_file = os.path.join(ckpt_path, network_name + ".ckpt")
    param_dict = load_checkpoint(ckpt_file)
    load_param_into_net(network, param_dict)