You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

convert_format.py 1.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from copy import deepcopy
  9. from .. import functional as F
  10. from ..module import Module
  11. from ..tensor import Tensor
  12. from ..core import _config
  def _is_nchw_format(param: Tensor) -> bool:
      """Return True if *param* looks like a tensor still awaiting NHWC conversion.

      A tensor qualifies when it has 4 or 5 dimensions and its ``format``
      attribute is not already ``"nhwc"``. ``format`` is only consulted once
      the rank check passes.

      Args:
          param: tensor to inspect; only ``shape`` and ``format`` are read.

      Returns:
          bool: True if the tensor should be converted to NHWC.
      """
      # TODO: use better condition -- rank alone does not prove the layout is NCHW.
      rank = len(param.shape)
      if rank not in (4, 5):
          return False
      return param.format != "nhwc"
  def convert_tensor_format(x: Tensor, inplace: bool = True):
      """Convert an NCHW Tensor to an NHWC Tensor.

      Axis permutation applied, by rank:
        * 4-D: ``(0, 2, 3, 1)``, i.e. NCHW -> NHWC.
        * 5-D: ``(0, 1, 3, 4, 2)``, i.e. axis 2 is moved to the end
          (presumably grouped-conv weights -- TODO confirm).

      The rank is validated first. A tensor whose ``format`` is already
      ``"nhwc"`` is then returned unchanged.

      Args:
          x: 4-D or 5-D tensor to convert.
          inplace: if True, overwrite ``x``'s contents with a host-side
              transposed copy (this destroys existing backward grad).
              If False, build a new tensor with ``F.transpose`` so the
              result keeps its gradient connection.

      Returns:
          ``x`` itself when no conversion is needed or ``inplace`` is True,
          otherwise a new transposed tensor.

      Raises:
          ValueError: if ``x.ndim`` is neither 4 nor 5.
      """
      permutations = {4: (0, 2, 3, 1), 5: (0, 1, 3, 4, 2)}
      pattern = permutations.get(x.ndim)
      if pattern is None:
          raise ValueError("Unsupport tensor ndim {}".format(x.ndim))

      # Already in the target layout: nothing to do.
      if x.format == "nhwc":
          return x

      # TODO: use initialization from tensor after fixing format setting
      if inplace:
          # Resetting the value through x[...] keeps the same tensor object,
          # but it destroys any existing backward grad.
          transposed = x.numpy().transpose(*pattern)
          x[...] = Tensor(transposed, format="nhwc")
          return x

      # Out-of-place: going through the MegEngine op maintains grad.
      converted = F.transpose(x, pattern)
      converted.format = "nhwc"
      return converted
  def convert_module_format(module: Module, inplace: bool = True):
      """Convert an NCHW Module to an NHWC Module.

      Every tensor yielded by ``named_tensors()`` that ``_is_nchw_format``
      accepts is converted in place with ``convert_tensor_format``.

      Args:
          module: module whose tensors are converted.
          inplace: if True, mutate ``module`` itself; otherwise convert a
              ``deepcopy`` of it and return that copy.

      Returns:
          The converted module: ``module`` itself when ``inplace`` is True,
          otherwise the converted deep copy.
      """
      target = module if inplace else deepcopy(module)
      for _, tensor in target.named_tensors():
          if not _is_nchw_format(tensor):
              continue
          # hostvalue should still be valid, so no d2h cost.
          convert_tensor_format(tensor, inplace=True)
      return target