You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_checkpoint.py 1.7 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import unittest
  3. from collections import OrderedDict
  4. import torch
  5. from torch import nn
  6. from detectron2.checkpoint.c2_model_loading import align_and_update_state_dicts
  7. from detectron2.utils.logger import setup_logger
  8. class TestCheckpointer(unittest.TestCase):
  9. def setUp(self):
  10. setup_logger()
  11. def create_complex_model(self):
  12. m = nn.Module()
  13. m.block1 = nn.Module()
  14. m.block1.layer1 = nn.Linear(2, 3)
  15. m.layer2 = nn.Linear(3, 2)
  16. m.res = nn.Module()
  17. m.res.layer2 = nn.Linear(3, 2)
  18. state_dict = OrderedDict()
  19. state_dict["layer1.weight"] = torch.rand(3, 2)
  20. state_dict["layer1.bias"] = torch.rand(3)
  21. state_dict["layer2.weight"] = torch.rand(2, 3)
  22. state_dict["layer2.bias"] = torch.rand(2)
  23. state_dict["res.layer2.weight"] = torch.rand(2, 3)
  24. state_dict["res.layer2.bias"] = torch.rand(2)
  25. return m, state_dict
  26. def test_complex_model_loaded(self):
  27. for add_data_parallel in [False, True]:
  28. model, state_dict = self.create_complex_model()
  29. if add_data_parallel:
  30. model = nn.DataParallel(model)
  31. model_sd = model.state_dict()
  32. align_and_update_state_dicts(model_sd, state_dict)
  33. for loaded, stored in zip(model_sd.values(), state_dict.values()):
  34. # different tensor references
  35. self.assertFalse(id(loaded) == id(stored))
  36. # same content
  37. self.assertTrue(loaded.equal(stored))
  38. if __name__ == "__main__":
  39. unittest.main()

No Description