|
|
@@ -12,51 +12,43 @@ |
|
|
# See the License for the specific language governing permissions and |
|
|
# See the License for the specific language governing permissions and |
|
|
# limitations under the License. |
|
|
# limitations under the License. |
|
|
# ============================================================================ |
|
|
# ============================================================================ |
|
|
"""Convert ckpt to air.""" |
|
|
|
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
import argparse |
|
|
import numpy as np |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
from mindspore import context |
|
|
|
|
|
from mindspore import Tensor |
|
|
|
|
|
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net |
|
|
|
|
|
|
|
|
import mindspore |
|
|
|
|
|
from mindspore import context, Tensor |
|
|
|
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export |
|
|
|
|
|
|
|
|
from src.centerface import CenterfaceMobilev2 |
|
|
from src.centerface import CenterfaceMobilev2 |
|
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) |
|
|
|
|
|
|
|
|
|
|
|
def save_air(): |
|
|
|
|
|
"""Save air file""" |
|
|
|
|
|
print('============= centerface start save air ==================') |
|
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='Convert ckpt to air') |
|
|
|
|
|
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load') |
|
|
|
|
|
parser.add_argument('--batch_size', type=int, default=8, help='batch size') |
|
|
|
|
|
|
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
network = CenterfaceMobilev2() |
|
|
|
|
|
|
|
|
|
|
|
if os.path.isfile(args.pretrained): |
|
|
|
|
|
param_dict = load_checkpoint(args.pretrained) |
|
|
|
|
|
param_dict_new = {} |
|
|
|
|
|
for key, values in param_dict.items(): |
|
|
|
|
|
if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'): |
|
|
|
|
|
continue |
|
|
|
|
|
elif key.startswith('centerface_network.'): |
|
|
|
|
|
param_dict_new[key[19:]] = values |
|
|
|
|
|
else: |
|
|
|
|
|
param_dict_new[key] = values |
|
|
|
|
|
load_param_into_net(network, param_dict_new) |
|
|
|
|
|
print('load model {} success'.format(args.pretrained)) |
|
|
|
|
|
|
|
|
|
|
|
input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 832, 832)).astype(np.float32) |
|
|
|
|
|
|
|
|
|
|
|
tensor_input_data = Tensor(input_data) |
|
|
|
|
|
export(network, tensor_input_data, |
|
|
|
|
|
file_name=args.pretrained.replace('.ckpt', '_' + str(args.batch_size) + 'b.air'), file_format='AIR') |
|
|
|
|
|
|
|
|
|
|
|
print("export model success.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
save_air() |
|
|
|
|
|
|
|
|
from src.config import ConfigCenterface |
|
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='centerface export') |
|
|
|
|
|
parser.add_argument("--device_id", type=int, default=0, help="Device id") |
|
|
|
|
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size") |
|
|
|
|
|
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.") |
|
|
|
|
|
parser.add_argument("--file_name", type=str, default="centerface.air", help="output file name.") |
|
|
|
|
|
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') |
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
|
config = ConfigCenterface() |
|
|
|
|
|
net = CenterfaceMobilev2() |
|
|
|
|
|
|
|
|
|
|
|
param_dict = load_checkpoint(args.ckpt_file) |
|
|
|
|
|
param_dict_new = {} |
|
|
|
|
|
for key, values in param_dict.items(): |
|
|
|
|
|
if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'): |
|
|
|
|
|
continue |
|
|
|
|
|
elif key.startswith('centerface_network.'): |
|
|
|
|
|
param_dict_new[key[19:]] = values |
|
|
|
|
|
else: |
|
|
|
|
|
param_dict_new[key] = values |
|
|
|
|
|
|
|
|
|
|
|
load_param_into_net(net, param_dict_new) |
|
|
|
|
|
net.set_train(False) |
|
|
|
|
|
|
|
|
|
|
|
input_data = Tensor(np.zeros([args.batch_size, 3, config.input_h, config.input_w]), mindspore.float32) |
|
|
|
|
|
export(net, input_data, file_name=args.file_name, file_format=args.file_format) |