|
|
|
@@ -14,9 +14,21 @@ |
|
|
|
# ============================================================================ |
|
|
|
"""hub config.""" |
|
|
|
from src.resnet_imgnet import resnet50 |
|
|
|
from mindspore import Tensor |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
def create_network(name, *args, **kwargs): |
|
|
|
if name == 'resnet-0.65x': |
|
|
|
return resnet50(*args, **kwargs) |
|
|
|
def get_index(filename): |
|
|
|
index = [] |
|
|
|
with open(filename) as fr: |
|
|
|
for line in fr: |
|
|
|
ind = Tensor((np.array(line.strip('\n').split(' ')[:-1])).astype(np.int32).reshape(-1, 1)) |
|
|
|
index.append(ind) |
|
|
|
return index |
|
|
|
|
|
|
|
|
|
|
|
def create_network(name, rate=0.65, index_filename='index.txt', **kwargs): |
|
|
|
index = get_index(index_filename) |
|
|
|
if name == 'resnet50-0.65x': |
|
|
|
return resnet50(rate=rate, index=index, **kwargs) |
|
|
|
raise NotImplementedError(f"{name} is not implemented in the repo") |