Browse Source

!8831 modify export script for centerface and yolov4

From: @yuzhenhua666
Reviewed-by: @c_34,@yingjy
Signed-off-by: @c_34
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
ce503eb3bc
2 changed files with 36 additions and 44 deletions
  1. +35
    -43
      model_zoo/official/cv/centerface/export.py
  2. +1
    -1
      model_zoo/official/cv/yolov4/export.py

+ 35
- 43
model_zoo/official/cv/centerface/export.py View File

@@ -12,51 +12,43 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Convert ckpt to air."""
import os

import argparse import argparse
import numpy as np import numpy as np


from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
import mindspore
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export


from src.centerface import CenterfaceMobilev2 from src.centerface import CenterfaceMobilev2

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)

def save_air():
"""Save air file"""
print('============= centerface start save air ==================')

parser = argparse.ArgumentParser(description='Convert ckpt to air')
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
parser.add_argument('--batch_size', type=int, default=8, help='batch size')

args = parser.parse_args()
network = CenterfaceMobilev2()

if os.path.isfile(args.pretrained):
param_dict = load_checkpoint(args.pretrained)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'):
continue
elif key.startswith('centerface_network.'):
param_dict_new[key[19:]] = values
else:
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
print('load model {} success'.format(args.pretrained))

input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 832, 832)).astype(np.float32)

tensor_input_data = Tensor(input_data)
export(network, tensor_input_data,
file_name=args.pretrained.replace('.ckpt', '_' + str(args.batch_size) + 'b.air'), file_format='AIR')

print("export model success.")


if __name__ == "__main__":
save_air()
from src.config import ConfigCenterface

parser = argparse.ArgumentParser(description='centerface export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="centerface.air", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)

if __name__ == '__main__':
config = ConfigCenterface()
net = CenterfaceMobilev2()

param_dict = load_checkpoint(args.ckpt_file)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'):
continue
elif key.startswith('centerface_network.'):
param_dict_new[key[19:]] = values
else:
param_dict_new[key] = values

load_param_into_net(net, param_dict_new)
net.set_train(False)

input_data = Tensor(np.zeros([args.batch_size, 3, config.input_h, config.input_w]), mindspore.float32)
export(net, input_data, file_name=args.file_name, file_format=args.file_format)

+ 1
- 1
model_zoo/official/cv/yolov4/export.py View File

@@ -26,7 +26,7 @@ parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size") parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--testing_shape", type=int, default=608, help="test shape") parser.add_argument("--testing_shape", type=int, default=608, help="test shape")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.") parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="ssd.air", help="output file name.")
parser.add_argument("--file_name", type=str, default="yolov4.air", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
args = parser.parse_args() args = parser.parse_args()




Loading…
Cancel
Save