From a059e8910f14fb0edd45faebabd8594833a6f3ee Mon Sep 17 00:00:00 2001 From: chenzomi Date: Thu, 30 Jul 2020 22:18:38 +0800 Subject: [PATCH] debug mindspore hub --- mindspore/hub.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/mindspore/hub.py b/mindspore/hub.py index 72013c8218..550c9ea719 100644 --- a/mindspore/hub.py +++ b/mindspore/hub.py @@ -32,13 +32,12 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net DOWNLOAD_BASIC_URL = "http://download.mindspore.cn/model_zoo" OFFICIAL_NAME = "official" -DEFAULT_CACHE_DIR = '~/.cache' -MODEL_TARGET_CV = ['alexnet', 'fasterrcnn', 'googlenet', - 'lenet', 'resnet', 'ssd', 'vgg', 'yolo'] +DEFAULT_CACHE_DIR = '.cache' +MODEL_TARGET_CV = ['alexnet', 'fasterrcnn', 'googlenet', 'lenet', 'resnet', 'resnet50', 'ssd', 'vgg', 'yolo'] MODEL_TARGET_NLP = ['bert', 'mass', 'transformer'] -def _packing_targz(output_filename, savepath="./"): +def _packing_targz(output_filename, savepath=DEFAULT_CACHE_DIR): """ Packing the input filename to filename.tar.gz in source dir. """ @@ -49,7 +48,7 @@ def _packing_targz(output_filename, savepath="./"): raise OSError("Cannot tar file {} for - {}".format(output_filename, e)) -def _unpacking_targz(input_filename, savepath="./"): +def _unpacking_targz(input_filename, savepath=DEFAULT_CACHE_DIR): """ Unpacking the input filename to dirs. """ @@ -69,14 +68,14 @@ def _remove_path_if_exists(path): def _create_path_if_not_exists(path): - if os.path.exists(path): + if not os.path.exists(path): if os.path.isfile(path): os.remove(path) else: os.mkdir(path) -def _get_weights_file(url, hash_md5=None, savepath='./'): +def _get_weights_file(url, hash_md5=None, savepath=DEFAULT_CACHE_DIR): """ get checkpoint weight from giving url. @@ -103,7 +102,8 @@ def _get_weights_file(url, hash_md5=None, savepath='./'): download_md5 = m.hexdigest() return download_md5 == hash_md5 - _create_path_if_not_exists(savepath) + _remove_path_if_exists(os.path.realpath(savepath)) + _create_path_if_not_exists(os.path.realpath(savepath)) ckpt_name = os.path.basename(url.split("/")[-1]) # identify file exist or not file_path = os.path.join(savepath, ckpt_name) @@ -112,8 +112,8 @@ def _get_weights_file(url, hash_md5=None, savepath='./'): print('File already exists!') return file_path - file_path = file_path[:-7] if ".tar.gz" in file_path else file_path - _remove_path_if_exists(file_path) + file_path_ = file_path[:-7] if ".tar.gz" in file_path else file_path + _remove_path_if_exists(file_path_) # download the checkpoint file print('Downloading data from url {}'.format(url)) @@ -126,14 +126,12 @@ def _get_weights_file(url, hash_md5=None, savepath='./'): print('\nDownload finished!') # untar file_path - _unpacking_targz(file_path) + _unpacking_targz(file_path, os.path.realpath(savepath)) - # # get the file size - file_path = os.path.join(savepath, ckpt_name) filesize = os.path.getsize(file_path) # turn the file size to Mb format print('File size = %.2f Mb' % (filesize / 1024 / 1024)) - return file_path + return file_path_ def _get_url_paths(url, ext='.tar.gz'): @@ -150,7 +148,7 @@ def _get_url_paths(url, ext='.tar.gz'): def _get_file_from_url(base_url, base_name): idx = 0 - urls = _get_url_paths(base_url) + urls = _get_url_paths(base_url + "/") files = [url.split('/')[-1] for url in urls] for i, name in enumerate(files): if re.match(base_name + '*', name) is not None: @@ -172,8 +170,8 @@ def load_weights(network, network_name=None, force_reload=True, **kwargs): dataset (string, optional): Dataset to train the network. Default: 'cifar10'. Example: - >>> mindspore.hub.load(network, network_name='lenet', - **{'device_target': 'ascend', 'dataset':'cifar10', 'version': 'beta0.5'}) + >>> hub.load_weights(network, network_name='lenet', + **{'device_target': 'ascend', 'dataset':'mnist', 'version': '0.5.0'}) """ if not isinstance(network, nn.Cell): logger.error("Failed to combine the net and the parameters.") @@ -195,9 +193,11 @@ def load_weights(network, network_name=None, force_reload=True, **kwargs): model_type = "cv" elif network_name.split("_")[0] in MODEL_TARGET_NLP: model_type = "nlp" + else: + raise ValueError("Unsupported network {} download checkpoint.".format(network_name.split("_")[0])) download_base_url = "/".join([DOWNLOAD_BASIC_URL, - OFFICIAL_NAME, model_type]) + OFFICIAL_NAME, model_type, network_name]) download_file_name = "_".join( [network_name, device_target, version, dataset, OFFICIAL_NAME]) download_url = _get_file_from_url(download_base_url, download_file_name)