You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 4.1 kB

2 years ago
Lines 1–137
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os
  3. import os.path as osp
  4. import warnings
  5. import numpy as np
  6. import onnx
  7. import onnxruntime as ort
  8. import torch
  9. import torch.nn as nn
  10. ort_custom_op_path = ''
  11. try:
  12. from mmcv.ops import get_onnxruntime_op_path
  13. ort_custom_op_path = get_onnxruntime_op_path()
  14. except (ImportError, ModuleNotFoundError):
  15. warnings.warn('If input model has custom op from mmcv, \
  16. you may have to build mmcv with ONNXRuntime from source.')
  17. class WrapFunction(nn.Module):
  18. """Wrap the function to be tested for torch.onnx.export tracking."""
  19. def __init__(self, wrapped_function):
  20. super(WrapFunction, self).__init__()
  21. self.wrapped_function = wrapped_function
  22. def forward(self, *args, **kwargs):
  23. return self.wrapped_function(*args, **kwargs)
  24. def ort_validate(model, feats, onnx_io='tmp.onnx'):
  25. """Validate the output of the onnxruntime backend is the same as the output
  26. generated by torch.
  27. Args:
  28. model (nn.Module | function): the function of model or model
  29. to be verified.
  30. feats (tuple(list(torch.Tensor)) | list(torch.Tensor) | torch.Tensor):
  31. the input of model.
  32. onnx_io (str): the name of onnx output file.
  33. """
  34. # if model is not an instance of nn.Module, then it is a normal
  35. # function and it should be wrapped.
  36. if isinstance(model, nn.Module):
  37. wrap_model = model
  38. else:
  39. wrap_model = WrapFunction(model)
  40. wrap_model.cpu().eval()
  41. with torch.no_grad():
  42. torch.onnx.export(
  43. wrap_model,
  44. feats,
  45. onnx_io,
  46. export_params=True,
  47. keep_initializers_as_inputs=True,
  48. do_constant_folding=True,
  49. verbose=False,
  50. opset_version=11)
  51. if isinstance(feats, tuple):
  52. ort_feats = []
  53. for feat in feats:
  54. ort_feats += feat
  55. else:
  56. ort_feats = feats
  57. # default model name: tmp.onnx
  58. onnx_outputs = get_ort_model_output(ort_feats)
  59. # remove temp file
  60. if osp.exists(onnx_io):
  61. os.remove(onnx_io)
  62. if isinstance(feats, tuple):
  63. torch_outputs = convert_result_list(wrap_model.forward(*feats))
  64. else:
  65. torch_outputs = convert_result_list(wrap_model.forward(feats))
  66. torch_outputs = [
  67. torch_output.detach().numpy() for torch_output in torch_outputs
  68. ]
  69. # match torch_outputs and onnx_outputs
  70. for i in range(len(onnx_outputs)):
  71. np.testing.assert_allclose(
  72. torch_outputs[i], onnx_outputs[i], rtol=1e-03, atol=1e-05)
  73. def get_ort_model_output(feat, onnx_io='tmp.onnx'):
  74. """Run the model in onnxruntime env.
  75. Args:
  76. feat (list[Tensor]): A list of tensors from torch.rand,
  77. each is a 4D-tensor.
  78. Returns:
  79. list[np.array]: onnxruntime infer result, each is a np.array
  80. """
  81. onnx_model = onnx.load(onnx_io)
  82. onnx.checker.check_model(onnx_model)
  83. session_options = ort.SessionOptions()
  84. # register custom op for onnxruntime
  85. if osp.exists(ort_custom_op_path):
  86. session_options.register_custom_ops_library(ort_custom_op_path)
  87. sess = ort.InferenceSession(onnx_io, session_options)
  88. if isinstance(feat, torch.Tensor):
  89. onnx_outputs = sess.run(None,
  90. {sess.get_inputs()[0].name: feat.numpy()})
  91. else:
  92. onnx_outputs = sess.run(None, {
  93. sess.get_inputs()[i].name: feat[i].numpy()
  94. for i in range(len(feat))
  95. })
  96. return onnx_outputs
  97. def convert_result_list(outputs):
  98. """Convert the torch forward outputs containing tuple or list to a list
  99. only containing torch.Tensor.
  100. Args:
  101. output (list(Tensor) | tuple(list(Tensor) | ...): the outputs
  102. in torch env, maybe containing nested structures such as list
  103. or tuple.
  104. Returns:
  105. list(Tensor): a list only containing torch.Tensor
  106. """
  107. # recursive end condition
  108. if isinstance(outputs, torch.Tensor):
  109. return [outputs]
  110. ret = []
  111. for sub in outputs:
  112. ret += convert_result_list(sub)
  113. return ret

No Description

Contributors (1)