|
|
|
@@ -32,13 +32,12 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net |
|
|
|
|
|
|
|
DOWNLOAD_BASIC_URL = "http://download.mindspore.cn/model_zoo" |
|
|
|
OFFICIAL_NAME = "official" |
|
|
|
DEFAULT_CACHE_DIR = '~/.cache' |
|
|
|
MODEL_TARGET_CV = ['alexnet', 'fasterrcnn', 'googlenet', |
|
|
|
'lenet', 'resnet', 'ssd', 'vgg', 'yolo'] |
|
|
|
DEFAULT_CACHE_DIR = '.cache' |
|
|
|
MODEL_TARGET_CV = ['alexnet', 'fasterrcnn', 'googlenet', 'lenet', 'resnet', 'resnet50', 'ssd', 'vgg', 'yolo'] |
|
|
|
MODEL_TARGET_NLP = ['bert', 'mass', 'transformer'] |
|
|
|
|
|
|
|
|
|
|
|
def _packing_targz(output_filename, savepath="./"): |
|
|
|
def _packing_targz(output_filename, savepath=DEFAULT_CACHE_DIR): |
|
|
|
""" |
|
|
|
Packing the input filename to filename.tar.gz in source dir. |
|
|
|
""" |
|
|
|
@@ -49,7 +48,7 @@ def _packing_targz(output_filename, savepath="./"): |
|
|
|
raise OSError("Cannot tar file {} for - {}".format(output_filename, e)) |
|
|
|
|
|
|
|
|
|
|
|
def _unpacking_targz(input_filename, savepath="./"): |
|
|
|
def _unpacking_targz(input_filename, savepath=DEFAULT_CACHE_DIR): |
|
|
|
""" |
|
|
|
Unpacking the input filename to dirs. |
|
|
|
""" |
|
|
|
@@ -69,14 +68,14 @@ def _remove_path_if_exists(path): |
|
|
|
|
|
|
|
|
|
|
|
def _create_path_if_not_exists(path): |
|
|
|
if os.path.exists(path): |
|
|
|
if not os.path.exists(path): |
|
|
|
if os.path.isfile(path): |
|
|
|
os.remove(path) |
|
|
|
else: |
|
|
|
os.mkdir(path) |
|
|
|
|
|
|
|
|
|
|
|
def _get_weights_file(url, hash_md5=None, savepath='./'): |
|
|
|
def _get_weights_file(url, hash_md5=None, savepath=DEFAULT_CACHE_DIR): |
|
|
|
""" |
|
|
|
get checkpoint weight from giving url. |
|
|
|
|
|
|
|
@@ -103,7 +102,8 @@ def _get_weights_file(url, hash_md5=None, savepath='./'): |
|
|
|
download_md5 = m.hexdigest() |
|
|
|
return download_md5 == hash_md5 |
|
|
|
|
|
|
|
_create_path_if_not_exists(savepath) |
|
|
|
_remove_path_if_exists(os.path.realpath(savepath)) |
|
|
|
_create_path_if_not_exists(os.path.realpath(savepath)) |
|
|
|
ckpt_name = os.path.basename(url.split("/")[-1]) |
|
|
|
# identify file exist or not |
|
|
|
file_path = os.path.join(savepath, ckpt_name) |
|
|
|
@@ -112,8 +112,8 @@ def _get_weights_file(url, hash_md5=None, savepath='./'): |
|
|
|
print('File already exists!') |
|
|
|
return file_path |
|
|
|
|
|
|
|
file_path = file_path[:-7] if ".tar.gz" in file_path else file_path |
|
|
|
_remove_path_if_exists(file_path) |
|
|
|
file_path_ = file_path[:-7] if ".tar.gz" in file_path else file_path |
|
|
|
_remove_path_if_exists(file_path_) |
|
|
|
|
|
|
|
# download the checkpoint file |
|
|
|
print('Downloading data from url {}'.format(url)) |
|
|
|
@@ -126,14 +126,12 @@ def _get_weights_file(url, hash_md5=None, savepath='./'): |
|
|
|
print('\nDownload finished!') |
|
|
|
|
|
|
|
# untar file_path |
|
|
|
_unpacking_targz(file_path) |
|
|
|
_unpacking_targz(file_path, os.path.realpath(savepath)) |
|
|
|
|
|
|
|
# # get the file size |
|
|
|
file_path = os.path.join(savepath, ckpt_name) |
|
|
|
filesize = os.path.getsize(file_path) |
|
|
|
# turn the file size to Mb format |
|
|
|
print('File size = %.2f Mb' % (filesize / 1024 / 1024)) |
|
|
|
return file_path |
|
|
|
return file_path_ |
|
|
|
|
|
|
|
|
|
|
|
def _get_url_paths(url, ext='.tar.gz'): |
|
|
|
@@ -150,7 +148,7 @@ def _get_url_paths(url, ext='.tar.gz'): |
|
|
|
|
|
|
|
def _get_file_from_url(base_url, base_name): |
|
|
|
idx = 0 |
|
|
|
urls = _get_url_paths(base_url) |
|
|
|
urls = _get_url_paths(base_url + "/") |
|
|
|
files = [url.split('/')[-1] for url in urls] |
|
|
|
for i, name in enumerate(files): |
|
|
|
if re.match(base_name + '*', name) is not None: |
|
|
|
@@ -172,8 +170,8 @@ def load_weights(network, network_name=None, force_reload=True, **kwargs): |
|
|
|
dataset (string, optional): Dataset to train the network. Default: 'cifar10'. |
|
|
|
|
|
|
|
Example: |
|
|
|
>>> mindspore.hub.load(network, network_name='lenet', |
|
|
|
**{'device_target': 'ascend', 'dataset':'cifar10', 'version': 'beta0.5'}) |
|
|
|
>>> hub.load_weights(network, network_name='lenet', |
|
|
|
**{'device_target': 'ascend', 'dataset':'mnist', 'version': '0.5.0'}) |
|
|
|
""" |
|
|
|
if not isinstance(network, nn.Cell): |
|
|
|
logger.error("Failed to combine the net and the parameters.") |
|
|
|
@@ -195,9 +193,11 @@ def load_weights(network, network_name=None, force_reload=True, **kwargs): |
|
|
|
model_type = "cv" |
|
|
|
elif network_name.split("_")[0] in MODEL_TARGET_NLP: |
|
|
|
model_type = "nlp" |
|
|
|
else: |
|
|
|
raise ValueError("Unsupported network {} download checkpoint.".format(network_name.split("_")[0])) |
|
|
|
|
|
|
|
download_base_url = "/".join([DOWNLOAD_BASIC_URL, |
|
|
|
OFFICIAL_NAME, model_type]) |
|
|
|
OFFICIAL_NAME, model_type, network_name]) |
|
|
|
download_file_name = "_".join( |
|
|
|
[network_name, device_target, version, dataset, OFFICIAL_NAME]) |
|
|
|
download_url = _get_file_from_url(download_base_url, download_file_name) |
|
|
|
|