You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_convert_format.py 2.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import numpy as np
  9. import pytest
  10. import megengine.functional as F
  11. import megengine.module as M
  12. from megengine import Parameter, Tensor, amp, config
  13. class MyModule(M.Module):
  14. class InnerModule(M.Module):
  15. def __init__(self):
  16. super().__init__()
  17. self.bn = M.BatchNorm2d(4)
  18. def forward(self, x):
  19. return self.bn(x)
  20. def __init__(self):
  21. super().__init__()
  22. self.i = self.InnerModule()
  23. self.conv = M.Conv2d(4, 4, 4, groups=2)
  24. self.bn = M.BatchNorm2d(4)
  25. self.param = Parameter(np.ones((1, 3, 1, 1), dtype=np.float32))
  26. self.buff = Tensor(np.ones((1, 3, 1, 1), dtype=np.float32))
  27. def forward(self, x):
  28. x = self.i(x)
  29. x = self.bn(x)
  30. return x
  31. @pytest.mark.parametrize("is_inplace", [False, True])
  32. def test_convert_module(is_inplace):
  33. m = MyModule()
  34. expected_shape = {
  35. "i.bn.weight": (1, 1, 1, 4),
  36. "i.bn.bias": (1, 1, 1, 4),
  37. "i.bn.running_mean": (1, 1, 1, 4),
  38. "i.bn.running_var": (1, 1, 1, 4),
  39. "conv.weight": (2, 2, 4, 4, 2),
  40. "conv.bias": (1, 1, 1, 4),
  41. "bn.weight": (1, 1, 1, 4),
  42. "bn.bias": (1, 1, 1, 4),
  43. "bn.running_mean": (1, 1, 1, 4),
  44. "bn.running_var": (1, 1, 1, 4),
  45. "param": (1, 1, 1, 3),
  46. "buff": (1, 1, 1, 3),
  47. }
  48. m = amp.convert_module_format(m, is_inplace)
  49. for name, param in m.named_tensors():
  50. assert param.format == "nhwc"
  51. with config._override(auto_format_convert=False):
  52. assert param.shape == expected_shape[name], name