You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

convert_format.py 1.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from copy import deepcopy
  9. from .. import functional as F
  10. from ..module import Module
  11. from ..tensor import Tensor
  12. from ..core import _config
  13. def _is_nchw_format(param: Tensor):
  14. # TODO: use better condition
  15. return (len(param.shape) == 4 or len(param.shape) == 5) and param.format != "nhwc"
  16. def convert_tensor_format(x: Tensor, inplace: bool = True):
  17. """Convert NCHW Tensor to NHWC Tensor."""
  18. if x.ndim == 4:
  19. pattern = (0, 2, 3, 1)
  20. elif x.ndim == 5:
  21. pattern = (0, 1, 3, 4, 2)
  22. else:
  23. raise ValueError("Unsupport tensor ndim {}".format(x.ndim))
  24. # TODO: use initialization from tensor after fixing format setting
  25. if x.format != "nhwc":
  26. if inplace:
  27. data = x.numpy().transpose(*pattern)
  28. x[...] = Tensor(data, format="nhwc")
  29. else:
  30. x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
  31. return x
  32. def convert_module_format(module: Module, inplace: bool = True):
  33. """Convert NCHW Module to NHWC Module."""
  34. if not inplace:
  35. module = deepcopy(module)
  36. for name, param in module.named_tensors():
  37. if _is_nchw_format(param):
  38. # hostvalue should still be valid, so no d2h cost.
  39. convert_tensor_format(param, inplace=True)
  40. return module