From 22aae8c4b7cdae64d3b4b7ca5d035083790b286d Mon Sep 17 00:00:00 2001 From: huzhifeng Date: Mon, 14 Sep 2020 15:29:58 +0800 Subject: [PATCH] remove hub from mindspore master --- hub/docs/.gitkeep | 0 hub/images/.gitkeep | 0 hub/scripts/.gitkeep | 0 mindspore/hub.py | 214 ------------------------------------------- 4 files changed, 214 deletions(-) delete mode 100644 hub/docs/.gitkeep delete mode 100644 hub/images/.gitkeep delete mode 100644 hub/scripts/.gitkeep delete mode 100644 mindspore/hub.py diff --git a/hub/docs/.gitkeep b/hub/docs/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/hub/images/.gitkeep b/hub/images/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/hub/scripts/.gitkeep b/hub/scripts/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/mindspore/hub.py b/mindspore/hub.py deleted file mode 100644 index 52a1e7754f..0000000000 --- a/mindspore/hub.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -hub for loading models: -Users can load pre-trained models using mindspore.hub.load() API. -""" -import os -import re -import shutil -import tarfile -import hashlib -from urllib.request import urlretrieve -import requests -from bs4 import BeautifulSoup - -import mindspore -import mindspore.nn as nn -from mindspore import log as logger -from mindspore.train.serialization import load_checkpoint, load_param_into_net - -DOWNLOAD_BASIC_URL = "http://download.mindspore.cn/model_zoo" -OFFICIAL_NAME = "official" -DEFAULT_CACHE_DIR = '.cache' -MODEL_TARGET_CV = ['alexnet', 'fasterrcnn', 'googlenet', 'lenet', 'resnet', 'resnet50', 'ssd', 'vgg', 'yolo'] -MODEL_TARGET_NLP = ['bert', 'mass', 'transformer'] - - -def _packing_targz(output_filename, savepath=DEFAULT_CACHE_DIR): - """ - Packing the input filename to filename.tar.gz in source dir. - """ - try: - with tarfile.open(output_filename, "w:gz") as tar: - tar.add(savepath, arcname=os.path.basename(savepath)) - except Exception as e: - raise OSError("Cannot tar file {} for - {}".format(output_filename, e)) - - -def _unpacking_targz(input_filename, savepath=DEFAULT_CACHE_DIR): - """ - Unpacking the input filename to dirs. - """ - try: - t = tarfile.open(input_filename) - t.extractall(path=savepath) - except Exception as e: - raise OSError("Cannot untar file {} for - {}".format(input_filename, e)) - - -def _remove_path_if_exists(path): - if os.path.exists(path): - if os.path.isfile(path): - os.remove(path) - else: - shutil.rmtree(path) - - -def _create_path_if_not_exists(path): - if not os.path.exists(path): - if os.path.isfile(path): - os.remove(path) - else: - os.mkdir(path) - - -def _get_weights_file(url, hash_md5=None, savepath=DEFAULT_CACHE_DIR): - """ - get checkpoint weight from giving url. - - Args: - url(string): checkpoint tar.gz url path. - hash_md5(string): checkpoint file md5. - savepath(string): checkpoint download save path. - - Returns: - string. - """ - - def reporthook(a, b, c): - percent = a * b * 100.0 / c - show_str = ('[%%-%ds]' % 70) % (int(percent * 80) * '#') - print("\rDownloading:", show_str, " %5.1f%%" % (percent), end="") - - def md5sum(file_name, hash_md5): - fp = open(file_name, 'rb') - content = fp.read() - fp.close() - m = hashlib.md5() - m.update(content.encode('utf-8')) - download_md5 = m.hexdigest() - return download_md5 == hash_md5 - - _remove_path_if_exists(os.path.realpath(savepath)) - _create_path_if_not_exists(os.path.realpath(savepath)) - ckpt_name = os.path.basename(url.split("/")[-1]) - # identify file exist or not - file_path = os.path.join(savepath, ckpt_name) - if os.path.isfile(file_path): - if hash_md5 and md5sum(file_path, hash_md5): - print('File already exists!') - return file_path - - file_path_ = file_path[:-7] if ".tar.gz" in file_path else file_path - _remove_path_if_exists(file_path_) - - # download the checkpoint file - print('Downloading data from url {}'.format(url)) - try: - urlretrieve(url, file_path, reporthook=reporthook) - except HTTPError as e: - raise Exception(e.code, e.msg, url) - except URLError as e: - raise Exception(e.errno, e.reason, url) - print('\nDownload finished!') - - # untar file_path - _unpacking_targz(file_path, os.path.realpath(savepath)) - - filesize = os.path.getsize(file_path) - # turn the file size to Mb format - print('File size = %.2f Mb' % (filesize / 1024 / 1024)) - return file_path_ - - -def _get_url_paths(url, ext='.tar.gz'): - response = requests.get(url) - if response.ok: - response_text = response.text - else: - return response.raise_for_status() - soup = BeautifulSoup(response_text, 'html.parser') - parent = [url + node.get('href') for node in soup.find_all('a') - if node.get('href').endswith(ext)] - return parent - - -def _get_file_from_url(base_url, base_name): - idx = 0 - urls = _get_url_paths(base_url + "/") - files = [url.split('/')[-1] for url in urls] - for i, name in enumerate(files): - if re.match(base_name + '*', name) is not None: - idx = i - break - return urls[idx] - - -def load_weights(network, network_name=None, force_reload=True, **kwargs): - r""" - Load a model from mindspore, with pretrained weights. - - Args: - network (Cell): Cell network. - network_name (string, optional): Cell network name get from network. Default: None. - force_reload (bool, optional): Whether to force a fresh download unconditionally. Default: False. - kwargs (dict, optional): The corresponding kwargs for download for model. - - - device_target (str, optional): Runtime device target. Default: 'ascend'. - - dataset (str, optional): Dataset to train the network. Default: 'cifar10'. - - version (str, optional): MindSpore version to save the checkpoint. Default: Latest version. - - Example: - >>> hub.load(network, network_name='lenet', - **{'device_target': 'ascend', 'dataset':'mnist', 'version': '0.5.0'}) - """ - if not isinstance(network, nn.Cell): - logger.error("Failed to combine the net and the parameters.") - msg = ("Argument net should be a Cell, but got {}.".format(type(network))) - raise TypeError(msg) - - if network_name is None: - if hasattr(network, network_name): - network_name = network.network_name - else: - msg = "Should input network name, but got None." - raise TypeError(msg) - - device_target = kwargs['device_target'] if kwargs['device_target'] else 'ascend' - dataset = kwargs['dataset'] if kwargs['dataset'] else 'imagenet' - version = kwargs['version'] if kwargs['version'] else mindspore.version.__version__ - - if network_name.split("_")[0] in MODEL_TARGET_CV: - model_type = "cv" - elif network_name.split("_")[0] in MODEL_TARGET_NLP: - model_type = "nlp" - else: - raise ValueError("Unsupported network {} download checkpoint.".format(network_name.split("_")[0])) - - download_base_url = "/".join([DOWNLOAD_BASIC_URL, - OFFICIAL_NAME, model_type, network_name]) - download_file_name = "_".join( - [network_name, device_target, version, dataset, OFFICIAL_NAME]) - download_url = _get_file_from_url(download_base_url, download_file_name) - - if force_reload: - ckpt_path = _get_weights_file(download_url, None, DEFAULT_CACHE_DIR) - else: - raise ValueError("Unsupported not force reload.") - - ckpt_file = os.path.join(ckpt_path, network_name + ".ckpt") - param_dict = load_checkpoint(ckpt_file) - load_param_into_net(network, param_dict)