Browse Source

!3757 debug mindspore hub

Merge pull request !3757 from chenzhongming/r0.6
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
983437feaf
1 changed files with 18 additions and 18 deletions
  1. +18
    -18
      mindspore/hub.py

+ 18
- 18
mindspore/hub.py View File

@@ -32,13 +32,12 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net

DOWNLOAD_BASIC_URL = "http://download.mindspore.cn/model_zoo"
OFFICIAL_NAME = "official"
DEFAULT_CACHE_DIR = '~/.cache'
MODEL_TARGET_CV = ['alexnet', 'fasterrcnn', 'googlenet',
'lenet', 'resnet', 'ssd', 'vgg', 'yolo']
DEFAULT_CACHE_DIR = '.cache'
MODEL_TARGET_CV = ['alexnet', 'fasterrcnn', 'googlenet', 'lenet', 'resnet', 'resnet50', 'ssd', 'vgg', 'yolo']
MODEL_TARGET_NLP = ['bert', 'mass', 'transformer']


def _packing_targz(output_filename, savepath="./"):
def _packing_targz(output_filename, savepath=DEFAULT_CACHE_DIR):
"""
Packing the input filename to filename.tar.gz in source dir.
"""
@@ -49,7 +48,7 @@ def _packing_targz(output_filename, savepath="./"):
raise OSError("Cannot tar file {} for - {}".format(output_filename, e))


def _unpacking_targz(input_filename, savepath="./"):
def _unpacking_targz(input_filename, savepath=DEFAULT_CACHE_DIR):
"""
Unpacking the input filename to dirs.
"""
@@ -69,14 +68,14 @@ def _remove_path_if_exists(path):


def _create_path_if_not_exists(path):
if os.path.exists(path):
if not os.path.exists(path):
if os.path.isfile(path):
os.remove(path)
else:
os.mkdir(path)


def _get_weights_file(url, hash_md5=None, savepath='./'):
def _get_weights_file(url, hash_md5=None, savepath=DEFAULT_CACHE_DIR):
"""
get checkpoint weight from giving url.

@@ -103,7 +102,8 @@ def _get_weights_file(url, hash_md5=None, savepath='./'):
download_md5 = m.hexdigest()
return download_md5 == hash_md5

_create_path_if_not_exists(savepath)
_remove_path_if_exists(os.path.realpath(savepath))
_create_path_if_not_exists(os.path.realpath(savepath))
ckpt_name = os.path.basename(url.split("/")[-1])
# identify file exist or not
file_path = os.path.join(savepath, ckpt_name)
@@ -112,8 +112,8 @@ def _get_weights_file(url, hash_md5=None, savepath='./'):
print('File already exists!')
return file_path

file_path = file_path[:-7] if ".tar.gz" in file_path else file_path
_remove_path_if_exists(file_path)
file_path_ = file_path[:-7] if ".tar.gz" in file_path else file_path
_remove_path_if_exists(file_path_)

# download the checkpoint file
print('Downloading data from url {}'.format(url))
@@ -126,14 +126,12 @@ def _get_weights_file(url, hash_md5=None, savepath='./'):
print('\nDownload finished!')

# untar file_path
_unpacking_targz(file_path)
_unpacking_targz(file_path, os.path.realpath(savepath))

# # get the file size
file_path = os.path.join(savepath, ckpt_name)
filesize = os.path.getsize(file_path)
# turn the file size to Mb format
print('File size = %.2f Mb' % (filesize / 1024 / 1024))
return file_path
return file_path_


def _get_url_paths(url, ext='.tar.gz'):
@@ -150,7 +148,7 @@ def _get_url_paths(url, ext='.tar.gz'):

def _get_file_from_url(base_url, base_name):
idx = 0
urls = _get_url_paths(base_url)
urls = _get_url_paths(base_url + "/")
files = [url.split('/')[-1] for url in urls]
for i, name in enumerate(files):
if re.match(base_name + '*', name) is not None:
@@ -172,8 +170,8 @@ def load_weights(network, network_name=None, force_reload=True, **kwargs):
dataset (string, optional): Dataset to train the network. Default: 'cifar10'.

Example:
>>> mindspore.hub.load(network, network_name='lenet',
**{'device_target': 'ascend', 'dataset':'cifar10', 'version': 'beta0.5'})
>>> hub.load_weights(network, network_name='lenet',
**{'device_target': 'ascend', 'dataset':'mnist', 'version': '0.5.0'})
"""
if not isinstance(network, nn.Cell):
logger.error("Failed to combine the net and the parameters.")
@@ -195,9 +193,11 @@ def load_weights(network, network_name=None, force_reload=True, **kwargs):
model_type = "cv"
elif network_name.split("_")[0] in MODEL_TARGET_NLP:
model_type = "nlp"
else:
raise ValueError("Unsupported network {} download checkpoint.".format(network_name.split("_")[0]))

download_base_url = "/".join([DOWNLOAD_BASIC_URL,
OFFICIAL_NAME, model_type])
OFFICIAL_NAME, model_type, network_name])
download_file_name = "_".join(
[network_name, device_target, version, dataset, OFFICIAL_NAME])
download_url = _get_file_from_url(download_base_url, download_file_name)


Loading…
Cancel
Save